ast_model_enter.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. package ams_ast
import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"text/template"

	"gorm.io/gorm"
)
  10. // 表信息结构体
  11. type TableInfo struct {
  12. Module string
  13. TableName string
  14. StructName string
  15. Fields []*FieldInfo
  16. HasDecimal bool
  17. HasTime bool
  18. HasJson bool
  19. }
  20. // 字段信息结构体
  21. type FieldInfo struct {
  22. Name string
  23. Type string
  24. GormTag string
  25. JsonTag string
  26. Comment string
  27. DefaultValue string
  28. IsPrimary bool // 是否主键
  29. IsNullable bool // 是否可为空
  30. IsIndex bool // 是否索引
  31. IsUnique bool // 是否唯一约束索引
  32. UniqueName string // 约束名
  33. }
  34. // 获取表的列信息
  35. type Column struct {
  36. ColumnName string `gorm:"column:COLUMN_NAME"`
  37. DataType string `gorm:"column:DATA_TYPE"`
  38. ColumnType string `gorm:"column:COLUMN_TYPE"`
  39. ColumnKey string `gorm:"column:COLUMN_KEY"`
  40. IsNullable string `gorm:"column:IS_NULLABLE"`
  41. ColumnDefault string `gorm:"column:COLUMN_DEFAULT"`
  42. ColumnComment string `gorm:"column:COLUMN_COMMENT"`
  43. NumericPrecision *int64 `gorm:"column:NUMERIC_PRECISION"`
  44. NumericScale *int64 `gorm:"column:NUMERIC_SCALE"`
  45. }
  46. // 唯一约束信息
  47. type Constraint struct {
  48. ConstraintName string `gorm:"column:CONSTRAINT_NAME"`
  49. ColumnName string `gorm:"column:COLUMN_NAME"`
  50. ReferencedTableName string `gorm:"column:REFERENCED_TABLE_NAME"`
  51. ReferencedColumnName string `gorm:"column:REFERENCED_COLUMN_NAME"`
  52. }
  53. // 索引信息
  54. type IndexInfo struct {
  55. IndexName string `gorm:"column:INDEX_NAME"`
  56. ColumnName string `gorm:"column:COLUMN_NAME"`
  57. NonUnique int `gorm:"column:NON_UNIQUE"`
  58. }
  59. // 模型注册器
  60. type ModelRegister struct {
  61. DbAlias string // 数据库别名
  62. ModelTemplateFile string `json:"modelTemplateFile"`
  63. TargetDir string `json:"targetDir"` // 文件生成目标文件夹
  64. TargetFile string `json:"targetFile"` // 模型文件
  65. TableName string `json:"tableName"`
  66. StructName string `json:"structName"` // 生成的结构体名称
  67. BaseAst
  68. }
  69. // 数据库 表名
  70. func BuildModelRegister(dbAlias, dbName, tableName string) *ModelRegister {
  71. register := &ModelRegister{
  72. DbAlias: dbAlias, // 数据库别名 biz
  73. TableName: tableName, // 表名称 demo_user
  74. StructName: CapitalizeOrLower(toCamelCase(tableName)), // 首字母大写驼峰写法 DemoUser
  75. }
  76. register.ModelTemplateFile = register.modelTemplateFile() // 模版文件地址
  77. register.TargetFile = register.modelTargetFile(tableName) // 模型目标文件
  78. return register
  79. }
  80. // 生成表结构
  81. func (s *ModelRegister) Initialize(db *gorm.DB) error {
  82. var files *template.Template
  83. files, err := template.ParseFiles(s.ModelTemplateFile)
  84. if err != nil {
  85. return err
  86. }
  87. //slog.Infof("模型模版文件:%s", s.ModelTemplateFile)
  88. err = os.MkdirAll(filepath.Dir(s.TargetFile), os.ModePerm)
  89. if err != nil {
  90. return err
  91. }
  92. if FileExists(s.TargetFile) {
  93. err = os.Remove(s.TargetFile)
  94. if err != nil {
  95. return err
  96. }
  97. }
  98. var file *os.File
  99. file, err = os.Create(s.TargetFile)
  100. if err != nil {
  101. return err
  102. }
  103. data, err := s.getTableInfo(db)
  104. if err != nil {
  105. return err
  106. }
  107. err = files.Execute(file, data)
  108. _ = file.Close()
  109. if err != nil {
  110. return err
  111. }
  112. return nil
  113. }
  114. // 回滚
  115. func (s *ModelRegister) RollbackModel() error {
  116. if FileExists(s.TargetFile) {
  117. err := os.Remove(s.TargetFile)
  118. if err != nil {
  119. return err
  120. }
  121. }
  122. return nil
  123. }
  124. // 模板文件地址
  125. func (s *ModelRegister) modelTemplateFile() string {
  126. return filepath.Join(template_root, "model/model.go.tpl")
  127. }
  128. // 生成模型文件路径
  129. func (s *ModelRegister) modelTargetFile(tableName string) string {
  130. return filepath.Join(s.modelTargetDir(), toSnakeCase(tableName)+".go")
  131. }
  132. // 模型目标文件夹 -- go_server/server/model/server_model_biz/DbAlias
  133. func (s *ModelRegister) modelTargetDir() string {
  134. return filepath.Join(modelTargetPath, server_model_biz, s.DbAlias)
  135. }
  136. // 获取表信息
  137. func (s *ModelRegister) getTableInfo(db *gorm.DB) (*TableInfo, error) {
  138. tableInfo := &TableInfo{
  139. Module: s.DbAlias,
  140. TableName: s.TableName,
  141. StructName: s.StructName,
  142. Fields: make([]*FieldInfo, 0),
  143. HasDecimal: false,
  144. HasTime: false,
  145. HasJson: false,
  146. }
  147. cols, err := s.getTableCols(db)
  148. if err != nil {
  149. return nil, err
  150. }
  151. constraint, err := s.getTableConstraint(db)
  152. if err != nil {
  153. return nil, err
  154. }
  155. indexes, err := s.getTableIndexInfo(db)
  156. if err != nil {
  157. return nil, err
  158. }
  159. for _, col := range cols {
  160. // 是否为Primary index
  161. // 获取唯一约束信息
  162. uniqueName, isUnique := s.colUniqueInfo(constraint, col.ColumnName)
  163. isIndex, isPrimary := s.colIndexInfo(indexes, col.ColumnName)
  164. if col.DataType == "decimal" {
  165. tableInfo.HasDecimal = true
  166. }
  167. if col.DataType == "datetime" {
  168. tableInfo.HasTime = true
  169. }
  170. if col.DataType == "json" {
  171. tableInfo.HasJson = true
  172. }
  173. filedInfo := &FieldInfo{
  174. Name: CapitalizeOrLower(toCamelCase(col.ColumnName)),
  175. Type: s.sqlTypeToGoType(col.DataType, col.ColumnType),
  176. GormTag: "",
  177. JsonTag: fmt.Sprintf(`json:"%s"`, toCamelCase(col.ColumnName)),
  178. Comment: col.ColumnComment,
  179. DefaultValue: col.ColumnDefault,
  180. IsNullable: col.IsNullable == "YES",
  181. IsPrimary: isPrimary,
  182. IsIndex: isIndex,
  183. IsUnique: isUnique,
  184. UniqueName: uniqueName,
  185. }
  186. // 组装 GormTag JsonTag
  187. // `json:"userId" gorm:"uniqueIndex:idx_asset_user_coin;comment:用户ID;NOT NULL"`
  188. // column:rate;comment:汇率
  189. gormTags := make([]string, 0)
  190. gormTags = append(gormTags, fmt.Sprintf("column:%s", col.ColumnName))
  191. gormTags = append(gormTags, fmt.Sprintf("type:%s", col.ColumnType))
  192. gormTags = append(gormTags, fmt.Sprintf("comment:%s", col.ColumnComment))
  193. if isIndex {
  194. if isPrimary {
  195. gormTags = append(gormTags, "primarykey")
  196. } else {
  197. gormTags = append(gormTags, "index")
  198. }
  199. }
  200. if isUnique && !isPrimary {
  201. gormTags = append(gormTags, fmt.Sprintf("unique:%s", filedInfo.UniqueName))
  202. }
  203. if !filedInfo.IsNullable {
  204. gormTags = append(gormTags, "NOT NULL")
  205. }
  206. filedInfo.GormTag = fmt.Sprintf(`gorm:"%s"`, strings.Join(gormTags, ";"))
  207. tableInfo.Fields = append(tableInfo.Fields, filedInfo)
  208. }
  209. return tableInfo, nil
  210. }
  211. func (s *ModelRegister) colUniqueInfo(items []*Constraint, colName string) (string, bool) {
  212. for _, item := range items {
  213. if item.ColumnName == colName {
  214. return item.ConstraintName, true
  215. }
  216. }
  217. return "", false
  218. }
  219. func (s *ModelRegister) colIndexInfo(items []*IndexInfo, colName string) (isIndex bool, isPrimary bool) {
  220. for _, item := range items {
  221. if item.ColumnName == colName {
  222. return true, item.IndexName == "PRIMARY"
  223. }
  224. }
  225. return false, false
  226. }
  227. // 获取列信息
  228. func (s *ModelRegister) getTableCols(db *gorm.DB) ([]*Column, error) {
  229. columns := make([]*Column, 0)
  230. if err := db.Raw(`
  231. SELECT
  232. COLUMN_NAME,
  233. DATA_TYPE,
  234. COLUMN_TYPE,
  235. COLUMN_KEY,
  236. IS_NULLABLE,
  237. COLUMN_DEFAULT,
  238. COLUMN_COMMENT,
  239. NUMERIC_PRECISION,
  240. NUMERIC_SCALE
  241. FROM INFORMATION_SCHEMA.COLUMNS
  242. WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
  243. ORDER BY ORDINAL_POSITION`, s.TableName).Scan(&columns).Error; err != nil {
  244. panic(err)
  245. }
  246. return columns, nil
  247. }
  248. // 获取表唯一约束信息
  249. func (s *ModelRegister) getTableConstraint(db *gorm.DB) ([]*Constraint, error) {
  250. constraints := make([]*Constraint, 0)
  251. if err := db.Raw(`
  252. SELECT
  253. k.CONSTRAINT_NAME,
  254. k.COLUMN_NAME,
  255. k.REFERENCED_TABLE_NAME,
  256. k.REFERENCED_COLUMN_NAME
  257. FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t
  258. JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE k
  259. USING(CONSTRAINT_NAME,TABLE_SCHEMA,TABLE_NAME)
  260. WHERE t.TABLE_SCHEMA = DATABASE()
  261. AND t.TABLE_NAME = ?
  262. AND t.CONSTRAINT_TYPE = 'UNIQUE'
  263. ORDER BY k.ORDINAL_POSITION`, s.TableName).Scan(&constraints).Error; err != nil {
  264. return nil, err
  265. }
  266. return constraints, nil
  267. }
  268. // 获取表索引信息
  269. func (s *ModelRegister) getTableIndexInfo(db *gorm.DB) ([]*IndexInfo, error) {
  270. // 获取索引信息
  271. indexes := make([]*IndexInfo, 0)
  272. if err := db.Raw(`
  273. SELECT
  274. INDEX_NAME,
  275. COLUMN_NAME,
  276. NON_UNIQUE
  277. FROM INFORMATION_SCHEMA.STATISTICS
  278. WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?
  279. ORDER BY INDEX_NAME, SEQ_IN_INDEX`, s.TableName).Scan(&indexes).Error; err != nil {
  280. return nil, err
  281. }
  282. return indexes, nil
  283. }
  284. // SQL类型转Go类型
  285. func (s *ModelRegister) sqlTypeToGoType(dataType, columnType string) string {
  286. switch strings.ToLower(dataType) {
  287. case "tinyint":
  288. if strings.Contains(columnType, "tinyint(1)") {
  289. return "bool"
  290. }
  291. return "int8"
  292. case "smallint":
  293. return "int16"
  294. case "mediumint", "int":
  295. return "int"
  296. case "bigint":
  297. return "int64"
  298. case "float":
  299. return "float32"
  300. case "double":
  301. return "float64"
  302. case "char", "varchar", "tinytext", "text", "mediumtext", "longtext", "enum", "set":
  303. return "string"
  304. case "date", "datetime", "timestamp", "time":
  305. return "time.Time"
  306. case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
  307. return "[]byte"
  308. case "json":
  309. return "datatypes.JSON"
  310. case "bit":
  311. return "[]uint8"
  312. case "decimal":
  313. return "decimal.Decimal"
  314. default:
  315. return "interface{}"
  316. }
  317. }