list_query.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package base
  2. import (
  3. "fmt"
  4. "go_server/utils"
  5. "gorm.io/gorm"
  6. "strings"
  7. )
  // BaseRequest is a generic request payload: an optional record id plus an
// arbitrary key/value data map. The id can also be bound from a form field.
type BaseRequest struct {
	Id   uint           `json:"id" form:"id"`
	Data map[string]any `json:"data"`
}
  // ListRequestColumn is a single filter condition of a list request.
//
// Key names the column to filter on, Op is the optional comparison operator
// (nil means equality), and Value is the operand; its expected dynamic type
// depends on Op (see ListRequest.Query).
type ListRequestColumn struct {
	Key   string  `json:"key"`
	Op    *string `json:"op"`
	Value any     `json:"value"`
}
  17. // 定义范型
  18. type ListRequestInterface[T any] interface {
  19. Query(db *gorm.DB) (*ListResponse[T], error)
  20. }
  21. // ListRequest T 为 model
  22. type ListRequest[T any] struct {
  23. Current int `json:"current" query:"current" form:"current" ` //当前页
  24. PageSize int `json:"pageSize" query:"pageSize" form:"pageSize" ` //每页条数
  25. Column []ListRequestColumn `json:"column"`
  26. Order *string `json:"order" form:"order" `
  27. }
  28. func NewListRequest[T any]() *ListRequest[T] {
  29. request := new(ListRequest[T])
  30. request.PageSize = 10
  31. request.Current = 1
  32. return request
  33. }
  34. type ListResponse[T any] struct {
  35. List []T `json:"list"`
  36. Paging *Pagination `json:"paging"`
  37. }
  38. func (req *ListRequest[T]) Query(db *gorm.DB) (*ListResponse[T], error) {
  39. list := make([]T, 0)
  40. query := db.Model(&list)
  41. if len(req.Column) > 0 {
  42. for _, column := range req.Column {
  43. // 验证列名安全性
  44. if !utils.SQLSecurity.ValidateColumnName(column.Key) {
  45. continue // 跳过不安全的列名
  46. }
  47. // 验证操作符安全性
  48. if column.Op != nil && !utils.SQLSecurity.ValidateOperator(*column.Op) {
  49. continue // 跳过不安全的操作符
  50. }
  51. if column.Op != nil && !strings.Contains(">,=,<,>=,<=,between,like", *column.Op) {
  52. continue
  53. }
  54. switch {
  55. case column.Op != nil && *column.Op == "between":
  56. if t, ok := column.Value.([]any); ok {
  57. if len(t) > 1 {
  58. query.Where(fmt.Sprintf("`%s` %s ? and ?", column.Key, *column.Op), t[0], t[1])
  59. }
  60. }
  61. case column.Op != nil && *column.Op == "like":
  62. if t, ok := column.Value.(string); ok {
  63. query.Where(fmt.Sprintf("`%s` like ?", column.Key), "%"+strings.TrimSpace(t)+"%")
  64. }
  65. case column.Op != nil:
  66. query.Where(fmt.Sprintf("`%s` %s ?", column.Key, *column.Op), column.Value)
  67. default:
  68. query.Where(fmt.Sprintf("`%s` = ?", column.Key), column.Value)
  69. }
  70. }
  71. }
  72. // 验证ORDER BY安全性
  73. if req.Order != nil && utils.SQLSecurity.ValidateOrderBy(*req.Order) {
  74. query = query.Order(utils.WordsToSnakeCase(*req.Order))
  75. }
  76. paging := NewPagination()
  77. paging.Current = req.Current
  78. paging.PageSize = req.PageSize
  79. if err := query.Count(&paging.Total).Error; err != nil {
  80. return nil, err
  81. }
  82. paging.Computer() //分页器计数
  83. query = query.Limit(paging.PageSize).Offset(paging.StartNums)
  84. err := query.Find(&list).Error
  85. if err != nil {
  86. return nil, err
  87. }
  88. return &ListResponse[T]{
  89. List: list,
  90. Paging: paging,
  91. }, nil
  92. }