106 lines
2.5 KiB
Go
106 lines
2.5 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gen"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
func main() {
|
|
dsn := flag.String("dsn", "", "consult[https://gorm.io/docs/connecting_to_the_database.html]")
|
|
db := flag.String("db", "", "input mysql|sqlite")
|
|
tables := flag.String("tables", "", "enter the required data table or leave it blank")
|
|
flag.Parse()
|
|
|
|
if *dsn == "" {
|
|
log.Fatal("dsn cannot be empty, please provide a valid dsn value.")
|
|
}
|
|
|
|
if *db == "" {
|
|
*db = "mysql"
|
|
}
|
|
|
|
// 初始化 GORM 数据库连接
|
|
var gormDb *gorm.DB
|
|
var err error
|
|
var daoPath string
|
|
var modelPath string
|
|
|
|
switch *db {
|
|
case "mysql":
|
|
gormDb, err = gorm.Open(mysql.Open(*dsn), &gorm.Config{
|
|
NamingStrategy: schema.NamingStrategy{
|
|
SingularTable: true,
|
|
},
|
|
})
|
|
|
|
daoPath = "/internal/repository/mysql/dao"
|
|
modelPath = "/internal/repository/mysql/model"
|
|
case "sqlite":
|
|
gormDb, err = gorm.Open(sqlite.Open(*dsn), &gorm.Config{
|
|
NamingStrategy: schema.NamingStrategy{
|
|
SingularTable: true,
|
|
},
|
|
})
|
|
|
|
daoPath = "/internal/repository/sqlite/dao"
|
|
modelPath = "/internal/repository/sqlite/model"
|
|
default:
|
|
log.Fatalf("Unsupported database type: %s. Supported types: mysql, sqlite", *db)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Fatalf("failed to connect database: %s", err.Error())
|
|
}
|
|
|
|
// 获取当前工作目录
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
log.Fatalf("Error getting working directory:%s", err.Error())
|
|
}
|
|
|
|
// 初始化代码生成器
|
|
g := gen.NewGenerator(gen.Config{
|
|
OutPath: wd + daoPath, // 指定生成代码的输出目录
|
|
ModelPkgPath: wd + modelPath, // 指定模型文件存放的包路径
|
|
Mode: gen.WithoutContext | gen.WithDefaultQuery,
|
|
})
|
|
|
|
// 使用已连接的数据库
|
|
g.UseDB(gormDb)
|
|
|
|
// 生成表的模型
|
|
tableMaps := strings.Split(*tables, ",")
|
|
if len(tableMaps) == 0 || *tables == "" {
|
|
g.ApplyBasic(
|
|
g.GenerateAllTable()...,
|
|
)
|
|
} else {
|
|
tableList := make([]string, 0, len(tableMaps))
|
|
for _, tableName := range tableMaps {
|
|
_tableName := strings.TrimSpace(string(tableName)) // trim leading and trailing space in tableName
|
|
if _tableName == "" { // skip empty tableName
|
|
continue
|
|
}
|
|
tableList = append(tableList, _tableName)
|
|
}
|
|
|
|
// Execute some data table tasks
|
|
models := make([]interface{}, len(tableList))
|
|
for i, tableName := range tableList {
|
|
models[i] = g.GenerateModel(tableName)
|
|
}
|
|
g.ApplyBasic(models...)
|
|
}
|
|
|
|
// 生成代码
|
|
g.Execute()
|
|
}
|