2025-10-16 13:28:24 +08:00

106 lines
2.5 KiB
Go

package main
import (
"flag"
"log"
"os"
"strings"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gen"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// main is the entry point of the GORM Gen code-generation tool.
//
// Flags:
//
//	-dsn     data source name handed to the selected driver (required)
//	-db      database type, "mysql" or "sqlite"; an empty value means "mysql"
//	-tables  comma-separated table names; when it contains no names, code is
//	         generated for every table returned by GenerateAllTable
//
// Query code is written to <working dir>/internal/repository/<db>/dao and
// model structs to <working dir>/internal/repository/<db>/model. A missing
// dsn, an unsupported database type, a failed connection, or a failure to
// resolve the working directory terminates the process through log.Fatal.
func main() {
	dsn := flag.String("dsn", "", "consult[https://gorm.io/docs/connecting_to_the_database.html]")
	db := flag.String("db", "", "input mysql|sqlite")
	tables := flag.String("tables", "", "enter the required data table or leave it blank")
	flag.Parse()

	if *dsn == "" {
		log.Fatal("dsn cannot be empty, please provide a valid dsn value.")
	}
	if *db == "" {
		*db = "mysql"
	}

	// Select the driver and the output directories for the requested database.
	var (
		dialector gorm.Dialector
		daoPath   string
		modelPath string
	)
	switch *db {
	case "mysql":
		dialector = mysql.Open(*dsn)
		daoPath = "/internal/repository/mysql/dao"
		modelPath = "/internal/repository/mysql/model"
	case "sqlite":
		dialector = sqlite.Open(*dsn)
		daoPath = "/internal/repository/sqlite/dao"
		modelPath = "/internal/repository/sqlite/model"
	default:
		log.Fatalf("Unsupported database type: %s. Supported types: mysql, sqlite", *db)
	}

	// Initialize the GORM database connection; both drivers share one config.
	gormDB, err := gorm.Open(dialector, &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			SingularTable: true,
		},
	})
	if err != nil {
		log.Fatalf("failed to connect database: %s", err.Error())
	}

	// Output paths are resolved relative to the current working directory.
	wd, err := os.Getwd()
	if err != nil {
		log.Fatalf("Error getting working directory:%s", err.Error())
	}

	// Initialize the code generator.
	g := gen.NewGenerator(gen.Config{
		OutPath:      wd + daoPath,   // output directory for the generated code
		ModelPkgPath: wd + modelPath, // package path where the model files are stored
		Mode:         gen.WithoutContext | gen.WithDefaultQuery,
	})

	// Reuse the connection opened above.
	g.UseDB(gormDB)

	// Generate models for the requested tables. A -tables value that holds no
	// actual names (empty, whitespace only, or just separators) is treated as
	// blank, so every table is generated instead of applying an empty model list.
	names := splitTableNames(*tables)
	if len(names) == 0 {
		g.ApplyBasic(g.GenerateAllTable()...)
	} else {
		models := make([]interface{}, 0, len(names))
		for _, name := range names {
			models = append(models, g.GenerateModel(name))
		}
		g.ApplyBasic(models...)
	}

	// Write the generated code.
	g.Execute()
}

// splitTableNames splits a comma-separated list of table names, trimming
// surrounding whitespace from each entry and dropping entries that are empty
// after trimming. It returns nil when raw contains no names.
func splitTableNames(raw string) []string {
	var names []string
	for _, part := range strings.Split(raw, ",") {
		if name := strings.TrimSpace(part); name != "" {
			names = append(names, name)
		}
	}
	return names
}