| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- package ams_ast
- import (
- "fmt"
- "gorm.io/gorm"
- "os"
- "path/filepath"
- "strings"
- "text/template"
- )
- // 表信息结构体
- type TableInfo struct {
- Module string
- TableName string
- StructName string
- Fields []*FieldInfo
- HasDecimal bool
- HasTime bool
- HasJson bool
- }
// FieldInfo describes one field of the generated struct. Each value is built
// by getTableInfo from a single table column.
type FieldInfo struct {
	Name, Type            string // Go field name and Go type of the field
	GormTag, JsonTag      string // complete `gorm:"..."` and `json:"..."` tag fragments
	Comment, DefaultValue string // column comment and column default, as read from the schema

	IsPrimary  bool // whether the column is part of the primary key
	IsNullable bool // whether the column accepts NULL
	IsIndex    bool // whether the column appears in any index
	IsUnique   bool // whether the column is covered by a unique constraint

	UniqueName string // name of that unique constraint; empty when IsUnique is false
}
// Column is the scan target for one row of INFORMATION_SCHEMA.COLUMNS, as
// selected by getTableCols. Every field maps to the schema column named in
// its gorm tag.
type Column struct {
	// Column name.
	ColumnName string `gorm:"column:COLUMN_NAME"`
	// Bare data type; this is what sqlTypeToGoType switches on.
	DataType string `gorm:"column:DATA_TYPE"`
	// Full column type string; emitted verbatim into the gorm `type:` tag.
	ColumnType string `gorm:"column:COLUMN_TYPE"`
	// Key marker reported by the schema; not read by getTableInfo.
	ColumnKey string `gorm:"column:COLUMN_KEY"`
	// "YES" when the column accepts NULL (getTableInfo compares against "YES").
	IsNullable string `gorm:"column:IS_NULLABLE"`
	// Column default value.
	ColumnDefault string `gorm:"column:COLUMN_DEFAULT"`
	// Column comment.
	ColumnComment string `gorm:"column:COLUMN_COMMENT"`
	// Numeric precision and scale; pointers so that NULL can be represented.
	NumericPrecision *int64 `gorm:"column:NUMERIC_PRECISION"`
	NumericScale     *int64 `gorm:"column:NUMERIC_SCALE"`
}
// Constraint is the scan target for one row returned by getTableConstraint,
// which selects only constraints of type UNIQUE. Each row pairs a constraint
// name with one column it covers.
type Constraint struct {
	// Name of the constraint.
	ConstraintName string `gorm:"column:CONSTRAINT_NAME"`
	// Column covered by the constraint.
	ColumnName string `gorm:"column:COLUMN_NAME"`
	// Referenced table and column as reported by KEY_COLUMN_USAGE; not read
	// anywhere in the visible code.
	ReferencedTableName  string `gorm:"column:REFERENCED_TABLE_NAME"`
	ReferencedColumnName string `gorm:"column:REFERENCED_COLUMN_NAME"`
}
// IndexInfo is the scan target for one row of INFORMATION_SCHEMA.STATISTICS,
// as selected by getTableIndexInfo: one (index, column) pair per row.
type IndexInfo struct {
	// Index name; colIndexInfo treats the name "PRIMARY" as the primary key.
	IndexName string `gorm:"column:INDEX_NAME"`
	// Column that takes part in the index.
	ColumnName string `gorm:"column:COLUMN_NAME"`
	// Raw NON_UNIQUE value from the schema; not read by the visible code.
	NonUnique int `gorm:"column:NON_UNIQUE"`
}
- // 模型注册器
- type ModelRegister struct {
- DbAlias string // 数据库别名
- ModelTemplateFile string `json:"modelTemplateFile"`
- TargetDir string `json:"targetDir"` // 文件生成目标文件夹
- TargetFile string `json:"targetFile"` // 模型文件
- TableName string `json:"tableName"`
- StructName string `json:"structName"` // 生成的结构体名称
- BaseAst
- }
- // 数据库 表名
- func BuildModelRegister(dbAlias, dbName, tableName string) *ModelRegister {
- register := &ModelRegister{
- DbAlias: dbAlias, // 数据库别名 biz
- TableName: tableName, // 表名称 demo_user
- StructName: CapitalizeOrLower(toCamelCase(tableName)), // 首字母大写驼峰写法 DemoUser
- }
- register.ModelTemplateFile = register.modelTemplateFile() // 模版文件地址
- register.TargetFile = register.modelTargetFile(tableName) // 模型目标文件
- return register
- }
- // 生成表结构
- func (s *ModelRegister) Initialize(db *gorm.DB) error {
- var files *template.Template
- files, err := template.ParseFiles(s.ModelTemplateFile)
- if err != nil {
- return err
- }
- //slog.Infof("模型模版文件:%s", s.ModelTemplateFile)
- err = os.MkdirAll(filepath.Dir(s.TargetFile), os.ModePerm)
- if err != nil {
- return err
- }
- if FileExists(s.TargetFile) {
- err = os.Remove(s.TargetFile)
- if err != nil {
- return err
- }
- }
- var file *os.File
- file, err = os.Create(s.TargetFile)
- if err != nil {
- return err
- }
- data, err := s.getTableInfo(db)
- if err != nil {
- return err
- }
- err = files.Execute(file, data)
- _ = file.Close()
- if err != nil {
- return err
- }
- return nil
- }
- // 回滚
- func (s *ModelRegister) RollbackModel() error {
- if FileExists(s.TargetFile) {
- err := os.Remove(s.TargetFile)
- if err != nil {
- return err
- }
- }
- return nil
- }
- // 模板文件地址
- func (s *ModelRegister) modelTemplateFile() string {
- return filepath.Join(template_root, "model/model.go.tpl")
- }
- // 生成模型文件路径
- func (s *ModelRegister) modelTargetFile(tableName string) string {
- return filepath.Join(s.modelTargetDir(), toSnakeCase(tableName)+".go")
- }
- // 模型目标文件夹 -- go_server/server/model/server_model_biz/DbAlias
- func (s *ModelRegister) modelTargetDir() string {
- return filepath.Join(modelTargetPath, server_model_biz, s.DbAlias)
- }
- // 获取表信息
- func (s *ModelRegister) getTableInfo(db *gorm.DB) (*TableInfo, error) {
- tableInfo := &TableInfo{
- Module: s.DbAlias,
- TableName: s.TableName,
- StructName: s.StructName,
- Fields: make([]*FieldInfo, 0),
- HasDecimal: false,
- HasTime: false,
- HasJson: false,
- }
- cols, err := s.getTableCols(db)
- if err != nil {
- return nil, err
- }
- constraint, err := s.getTableConstraint(db)
- if err != nil {
- return nil, err
- }
- indexes, err := s.getTableIndexInfo(db)
- if err != nil {
- return nil, err
- }
- for _, col := range cols {
- // 是否为Primary index
- // 获取唯一约束信息
- uniqueName, isUnique := s.colUniqueInfo(constraint, col.ColumnName)
- isIndex, isPrimary := s.colIndexInfo(indexes, col.ColumnName)
- if col.DataType == "decimal" {
- tableInfo.HasDecimal = true
- }
- if col.DataType == "datetime" {
- tableInfo.HasTime = true
- }
- if col.DataType == "json" {
- tableInfo.HasJson = true
- }
- filedInfo := &FieldInfo{
- Name: CapitalizeOrLower(toCamelCase(col.ColumnName)),
- Type: s.sqlTypeToGoType(col.DataType, col.ColumnType),
- GormTag: "",
- JsonTag: fmt.Sprintf(`json:"%s"`, toCamelCase(col.ColumnName)),
- Comment: col.ColumnComment,
- DefaultValue: col.ColumnDefault,
- IsNullable: col.IsNullable == "YES",
- IsPrimary: isPrimary,
- IsIndex: isIndex,
- IsUnique: isUnique,
- UniqueName: uniqueName,
- }
- // 组装 GormTag JsonTag
- // `json:"userId" gorm:"uniqueIndex:idx_asset_user_coin;comment:用户ID;NOT NULL"`
- // column:rate;comment:汇率
- gormTags := make([]string, 0)
- gormTags = append(gormTags, fmt.Sprintf("column:%s", col.ColumnName))
- gormTags = append(gormTags, fmt.Sprintf("type:%s", col.ColumnType))
- gormTags = append(gormTags, fmt.Sprintf("comment:%s", col.ColumnComment))
- if isIndex {
- if isPrimary {
- gormTags = append(gormTags, "primarykey")
- } else {
- gormTags = append(gormTags, "index")
- }
- }
- if isUnique && !isPrimary {
- gormTags = append(gormTags, fmt.Sprintf("unique:%s", filedInfo.UniqueName))
- }
- if !filedInfo.IsNullable {
- gormTags = append(gormTags, "NOT NULL")
- }
- filedInfo.GormTag = fmt.Sprintf(`gorm:"%s"`, strings.Join(gormTags, ";"))
- tableInfo.Fields = append(tableInfo.Fields, filedInfo)
- }
- return tableInfo, nil
- }
- func (s *ModelRegister) colUniqueInfo(items []*Constraint, colName string) (string, bool) {
- for _, item := range items {
- if item.ColumnName == colName {
- return item.ConstraintName, true
- }
- }
- return "", false
- }
- func (s *ModelRegister) colIndexInfo(items []*IndexInfo, colName string) (isIndex bool, isPrimary bool) {
- for _, item := range items {
- if item.ColumnName == colName {
- return true, item.IndexName == "PRIMARY"
- }
- }
- return false, false
- }
- // 获取列信息
- func (s *ModelRegister) getTableCols(db *gorm.DB) ([]*Column, error) {
- columns := make([]*Column, 0)
- if err := db.Raw(`
- SELECT
- COLUMN_NAME,
- DATA_TYPE,
- COLUMN_TYPE,
- COLUMN_KEY,
- IS_NULLABLE,
- COLUMN_DEFAULT,
- COLUMN_COMMENT,
- NUMERIC_PRECISION,
- NUMERIC_SCALE
- FROM INFORMATION_SCHEMA.COLUMNS
- WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
- ORDER BY ORDINAL_POSITION`, s.TableName).Scan(&columns).Error; err != nil {
- panic(err)
- }
- return columns, nil
- }
- // 获取表唯一约束信息
- func (s *ModelRegister) getTableConstraint(db *gorm.DB) ([]*Constraint, error) {
- constraints := make([]*Constraint, 0)
- if err := db.Raw(`
- SELECT
- k.CONSTRAINT_NAME,
- k.COLUMN_NAME,
- k.REFERENCED_TABLE_NAME,
- k.REFERENCED_COLUMN_NAME
- FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t
- JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE k
- USING(CONSTRAINT_NAME,TABLE_SCHEMA,TABLE_NAME)
- WHERE t.TABLE_SCHEMA = DATABASE()
- AND t.TABLE_NAME = ?
- AND t.CONSTRAINT_TYPE = 'UNIQUE'
- ORDER BY k.ORDINAL_POSITION`, s.TableName).Scan(&constraints).Error; err != nil {
- return nil, err
- }
- return constraints, nil
- }
- // 获取表索引信息
- func (s *ModelRegister) getTableIndexInfo(db *gorm.DB) ([]*IndexInfo, error) {
- // 获取索引信息
- indexes := make([]*IndexInfo, 0)
- if err := db.Raw(`
- SELECT
- INDEX_NAME,
- COLUMN_NAME,
- NON_UNIQUE
- FROM INFORMATION_SCHEMA.STATISTICS
- WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
- ORDER BY INDEX_NAME, SEQ_IN_INDEX`, s.TableName).Scan(&indexes).Error; err != nil {
- return nil, err
- }
- return indexes, nil
- }
- // SQL类型转Go类型
- func (s *ModelRegister) sqlTypeToGoType(dataType, columnType string) string {
- switch strings.ToLower(dataType) {
- case "tinyint":
- if strings.Contains(columnType, "tinyint(1)") {
- return "bool"
- }
- return "int8"
- case "smallint":
- return "int16"
- case "mediumint", "int":
- return "int"
- case "bigint":
- return "int64"
- case "float":
- return "float32"
- case "double":
- return "float64"
- case "char", "varchar", "tinytext", "text", "mediumtext", "longtext", "enum", "set":
- return "string"
- case "date", "datetime", "timestamp", "time":
- return "time.Time"
- case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
- return "[]byte"
- case "json":
- return "datatypes.JSON"
- case "bit":
- return "[]uint8"
- case "decimal":
- return "decimal.Decimal"
- default:
- return "interface{}"
- }
- }
|