jwt.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/golang-jwt/jwt/v5"
  7. "go_server/base/config"
  8. "go_server/base/core"
  9. "go_server/model/common/response"
  10. "net/http"
  11. "strings"
  12. "time"
  13. )
  14. const AuthorizationHeader = "Authorization"
  15. // JwtMiddleware JWT中间件, 强制要求用户登录
  16. func JwtMiddleware() gin.HandlerFunc {
  17. return func(c *gin.Context) {
  18. tokenString := c.GetHeader(AuthorizationHeader)
  19. if tokenString == "" {
  20. c.AbortWithStatusJSON(http.StatusUnauthorized, response.ErrorObjByCode(response.ResponseCodeMissAuthToken))
  21. return
  22. }
  23. // 解析JWT
  24. member, err := ParseJWT(tokenString)
  25. if err != nil || member == nil || member.ExpiresAt.Time.Before(time.Now()) {
  26. c.AbortWithStatusJSON(http.StatusUnauthorized, response.ErrorObjByCode(response.ResponseCodeTokenInvalid))
  27. return
  28. }
  29. // token有效,设置用户信息到上下文
  30. setClaimsToContext(c, member)
  31. // 继续执行
  32. c.Next()
  33. }
  34. }
  35. // setClaimsToContext 设置用户信息到上下文
  36. func setClaimsToContext(c *gin.Context, member *MyClaims) {
  37. c.Set("userId", member.UserID)
  38. c.Set("roleId", member.RoleId)
  39. c.Set("exp", member.ExpiresAt)
  40. c.Set("issuer", member.Issuer)
  41. }
  42. // OptionalJwtMiddleware 允许用户携带 JWT,但不强制要求登录
  43. func OptionalJwtMiddleware() gin.HandlerFunc {
  44. return func(c *gin.Context) {
  45. tokenString := c.GetHeader(AuthorizationHeader)
  46. if tokenString != "" {
  47. // 解析JWT
  48. member, err := ParseJWT(tokenString)
  49. if err == nil {
  50. if member != nil {
  51. setClaimsToContext(c, member)
  52. }
  53. } else {
  54. core.Log.Infof("OptionalJwtMiddleware:%s", err.Error())
  55. }
  56. }
  57. c.Next()
  58. }
  59. }
  60. type Member struct {
  61. ID int64
  62. RoleId int64
  63. }
  64. // Secret key used to sign the JWT token (In production, use environment variables to keep it secure)
  65. var jwtSecretStr = "admin.secret.1234565"
  66. var jwtSecretByte = []byte(jwtSecretStr)
  67. var jwtExpireDuration = int64(24) // in seconds
  68. var jwtIssuer = "issuer"
  69. // MyClaims 定义了JWT中的自定义声明
  70. type MyClaims struct {
  71. UserID int64 `json:"userId"`
  72. RoleId int64 `json:"roleId"`
  73. jwt.RegisteredClaims
  74. }
  75. func init() {
  76. jwtSecretStr = config.EnvConf().JWT.SigningKey
  77. jwtSecretByte = []byte(jwtSecretStr)
  78. jwtExpireDuration = config.EnvConf().JWT.ExpiresTime
  79. jwtIssuer = config.EnvConf().JWT.Issuer
  80. }
  81. // GenerateJWT 根据用户信息生成JWT
  82. func GenerateJWT(user Member) (string, error) {
  83. // 设置JWT的过期时间
  84. d := time.Duration(jwtExpireDuration) * time.Hour
  85. expirationTime := time.Now().Add(d)
  86. // 创建JWT的声明
  87. claims := MyClaims{
  88. UserID: user.ID,
  89. RoleId: user.RoleId,
  90. RegisteredClaims: jwt.RegisteredClaims{
  91. ID: fmt.Sprintf("%d", user.ID), // 设置ID
  92. ExpiresAt: jwt.NewNumericDate(expirationTime), // 设置过期时间
  93. Issuer: jwtIssuer, // 设置签发者
  94. },
  95. }
  96. // 使用HS256算法签署JWT
  97. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  98. // 生成并返回签名后的JWT字符串
  99. tokenString, err := token.SignedString(jwtSecretByte)
  100. if err != nil {
  101. return "", err
  102. }
  103. return tokenString, nil
  104. }
  105. // ParseJWT 解析和验证JWT
  106. func ParseJWT(tokenString string) (*MyClaims, error) {
  107. // 定义一个空的声明对象
  108. claims := &MyClaims{}
  109. // 去掉 "Bearer " 部分
  110. tokenString = strings.TrimPrefix(tokenString, "Bearer ")
  111. // 解析JWT
  112. token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
  113. // 验证签名算法
  114. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  115. return nil, errors.New("unexpected signing method")
  116. }
  117. return jwtSecretByte, nil
  118. })
  119. if err != nil || !token.Valid {
  120. return nil, errors.New("invalid token")
  121. }
  122. return claims, nil
  123. }
  124. // 解析gin上下文中的用户信息
  125. func ParseUser(c *gin.Context) *MyClaims {
  126. userId, _ := c.Get("userId")
  127. roleId, _ := c.Get("openId")
  128. issuer, _ := c.Get("issuer")
  129. claims := &MyClaims{
  130. UserID: userId.(int64),
  131. RoleId: roleId.(int64),
  132. RegisteredClaims: jwt.RegisteredClaims{
  133. Issuer: issuer.(string),
  134. },
  135. }
  136. return claims
  137. }