rate_limiter.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package core
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/redis/go-redis/v9"
  6. "time"
  7. )
  8. // 限流器
  9. type RateLimiter struct {
  10. client redis.UniversalClient
  11. }
  12. func NewRateLimiter(client redis.UniversalClient) *RateLimiter {
  13. return &RateLimiter{client: client}
  14. }
  15. func (r *RateLimiter) IsAllowedSimple(ctx context.Context, key string, limit int, window time.Duration) (bool, error) {
  16. // 先检查 key 的类型,如果是错误类型则删除
  17. keyType, err := r.client.Type(ctx, key).Result()
  18. if err != nil {
  19. return false, err
  20. }
  21. // 如果 key 存在且不是 string 类型,删除它
  22. if keyType != "none" && keyType != "string" {
  23. r.client.Del(ctx, key)
  24. }
  25. // 使用 INCR 和 EXPIRE 的简化方案
  26. count, err := r.client.Incr(ctx, key).Result()
  27. if err != nil {
  28. return false, err
  29. }
  30. if count == 1 {
  31. // 第一次设置时添加过期时间
  32. r.client.Expire(ctx, key, window)
  33. }
  34. return count <= int64(limit), nil
  35. }
  36. // 检查是否可以执行方法
  37. func (r *RateLimiter) CanExecuteMethod(ctx context.Context, methodName string, info string) (bool, error) {
  38. key := fmt.Sprintf("rate_limit:%s:%s", methodName, info)
  39. return r.IsAllowedSimple(ctx, key, 10, 10*time.Minute)
  40. }
  41. // ClearLimit 清除指定 key 的限流缓存
  42. func (r *RateLimiter) ClearLimit(ctx context.Context, methodName string, info string) error {
  43. key := fmt.Sprintf("rate_limit:%s:%s", methodName, info)
  44. return r.client.Del(ctx, key).Err()
  45. }
  46. // ClearLimitByPattern 根据模式匹配清除多个限流缓存
  47. func (r *RateLimiter) ClearLimitByPattern(ctx context.Context, pattern string) error {
  48. // 使用 SCAN 命令避免在大量 key 时阻塞 Redis
  49. iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
  50. var keys []string
  51. for iter.Next(ctx) {
  52. keys = append(keys, iter.Val())
  53. }
  54. if err := iter.Err(); err != nil {
  55. return err
  56. }
  57. if len(keys) > 0 {
  58. return r.client.Del(ctx, keys...).Err()
  59. }
  60. return nil
  61. }
  62. func (r *RateLimiter) ClearMethodLimit(ctx context.Context) error {
  63. key := fmt.Sprintf("rate_limit:*")
  64. return r.ClearLimitByPattern(ctx, key)
  65. }