ast_router_root_enter.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package ams_ast
  2. import (
  3. "bytes"
  4. "fmt"
  5. "go/ast"
  6. "go/parser"
  7. "go/printer"
  8. "go/token"
  9. "os"
  10. "path/filepath"
  11. )
// RootRouterRegister (root router registrar) edits the server's root router
// entry file: it adds or removes `<module>.<StructName>{}` elements in the
// router list variable named ValKey and keeps the matching
// "go_server/router/<module>" imports in sync.
type RootRouterRegister struct {
	// Example: /Users/***/app_manage_system/go_server/router/enter.go
	TargetFile  string   `json:"targetFile"`  // target file (fixed); an absolute path is used at build time, e.g. go_server/router/enter.go
	ModuleNames []string `json:"moduleNames"` // module names; each is used as the package identifier in the element and as the last element of the router package import path
	StructName  string   `json:"structName"`  // router struct name inside each module package, e.g. RouterGroup
	ValKey      string   `json:"valKey"`      // name of the variable holding the router list, e.g. priRouters
	// NOTE(review): presumably supplies AddImport/RemoveImport/IsContainsModule/IsTargetModule
	// used by the methods of this type — confirm against BaseAst's definition.
	BaseAst
}
  21. // 服务根路由入口文件地址
  22. func ServerRootRouterEnterFile() string {
  23. return serverRootRouterEnterFile
  24. }
  25. // 支持同时注册多个模块 moduleNames
  26. func BuildRootRouterRegister(alias []string) *RootRouterRegister {
  27. return &RootRouterRegister{
  28. TargetFile: serverRootRouterEnterFile,
  29. ModuleNames: alias,
  30. StructName: "RouterGroup",
  31. ValKey: "priRouters",
  32. }
  33. }
  34. func (r *RootRouterRegister) routerModulePkg(module string) string {
  35. return fmt.Sprintf("go_server/router/%s", module)
  36. }
  37. // 总路由注册
  38. func (r *RootRouterRegister) Inspect() error {
  39. // 读取目标文件
  40. src, err := os.ReadFile(r.TargetFile)
  41. if err != nil {
  42. return err
  43. }
  44. // 创建 FileSet 用于跟踪文件位置信息
  45. fSet := token.NewFileSet()
  46. // 解析源代码为 AST
  47. f, err := parser.ParseFile(fSet, "", src, 0)
  48. if err != nil {
  49. return err
  50. }
  51. ast.Inspect(f, func(n ast.Node) bool {
  52. if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.VAR {
  53. for _, spec := range decl.Specs {
  54. if valueSpec, ok := spec.(*ast.ValueSpec); ok {
  55. for i, name := range valueSpec.Names {
  56. if name.Name == r.ValKey && len(valueSpec.Values) > i {
  57. if compLit, ok := valueSpec.Values[i].(*ast.CompositeLit); ok {
  58. // 元素确定
  59. if !r.isGlobalContextInterface(compLit) {
  60. return false
  61. }
  62. for _, moduleName := range r.ModuleNames {
  63. // 添加新的元素 ModelName.RouterGroup{}
  64. newElement := &ast.CompositeLit{
  65. Type: &ast.SelectorExpr{
  66. X: ast.NewIdent(moduleName),
  67. Sel: ast.NewIdent(r.StructName),
  68. },
  69. }
  70. // 检查是否已注入
  71. if !r.IsContainsModule(compLit.Elts, moduleName, r.StructName) {
  72. compLit.Elts = append(compLit.Elts, newElement)
  73. }
  74. // 确保导入了 模块对应的包
  75. normalizedPkg := filepath.ToSlash(moduleName)
  76. r.AddImport(f, r.routerModulePkg(normalizedPkg))
  77. }
  78. }
  79. }
  80. }
  81. }
  82. }
  83. }
  84. return true
  85. })
  86. var out []byte
  87. bf := bytes.NewBuffer(out)
  88. err = printer.Fprint(bf, fSet, f)
  89. if err != nil {
  90. return err
  91. }
  92. return os.WriteFile(r.TargetFile, bf.Bytes(), 0666)
  93. }
  94. // 可做进一步检查检查类型是否是 []global.ContextInterface
  95. func (r *RootRouterRegister) isGlobalContextInterface(compLit *ast.CompositeLit) bool {
  96. if arrType, ok := compLit.Type.(*ast.ArrayType); ok {
  97. if selExpr, ok := arrType.Elt.(*ast.SelectorExpr); ok {
  98. if xIdent, ok := selExpr.X.(*ast.Ident); ok && xIdent.Name == "global" {
  99. if selExpr.Sel.Name == "ContextInterface" {
  100. // 确认是我们要找的类型
  101. return true
  102. }
  103. }
  104. }
  105. }
  106. return false
  107. }
  108. // 总路由回滚 -- 仅支持单一模块回滚移除
  109. func (r *RootRouterRegister) RollbackRootRouter() error {
  110. // 读取目标文件
  111. src, err := os.ReadFile(r.TargetFile)
  112. if err != nil {
  113. return err
  114. }
  115. // 创建 FileSet 用于跟踪文件位置信息
  116. fSet := token.NewFileSet()
  117. // 解析源代码为 AST
  118. f, err := parser.ParseFile(fSet, "", src, 0)
  119. if err != nil {
  120. return err
  121. }
  122. ast.Inspect(f, func(n ast.Node) bool {
  123. if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.VAR {
  124. for _, spec := range decl.Specs {
  125. if valueSpec, ok := spec.(*ast.ValueSpec); ok {
  126. for i, name := range valueSpec.Names {
  127. if name.Name == r.ValKey && len(valueSpec.Values) > i {
  128. if compLit, ok := valueSpec.Values[i].(*ast.CompositeLit); ok {
  129. // 元素确定
  130. if !r.isGlobalContextInterface(compLit) {
  131. return false
  132. }
  133. newEls := make([]ast.Expr, 0, len(compLit.Elts))
  134. removeImport := false
  135. for _, elt := range compLit.Elts {
  136. if !r.IsTargetModule(elt, r.ModuleNames[0], r.StructName) {
  137. newEls = append(newEls, elt)
  138. } else {
  139. removeImport = true
  140. }
  141. if removeImport {
  142. // 确保删除模块对应的包
  143. r.RemoveImport(f, r.routerModulePkg(r.ModuleNames[0]))
  144. }
  145. }
  146. compLit.Elts = newEls
  147. }
  148. }
  149. }
  150. }
  151. }
  152. }
  153. return true
  154. })
  155. var out []byte
  156. bf := bytes.NewBuffer(out)
  157. err = printer.Fprint(bf, fSet, f)
  158. if err != nil {
  159. return err
  160. }
  161. err = os.Remove(r.TargetFile)
  162. if err != nil {
  163. return err
  164. }
  165. return os.WriteFile(r.TargetFile, bf.Bytes(), 0666)
  166. }