Quellcode durchsuchen

修复了密码明文存储、暴力破解防护(4处)、事务/原子操作安全(6处)、SQL 正确性(ROW_NUMBER、CONCAT)、GORM 状态污染、N+1 查询优化、登录响应补全 email、以及 4处硬编码中文改用 ctx.I18n() 国际化

urbanu vor 1 Monat
Ursprung
Commit
da34c6b772

+ 240 - 86
apis/daytask/auth.go

@@ -12,8 +12,10 @@ import (
 	"io"
 	"math/big"
 	"net/http"
+	"net/smtp"
 	"net/url"
 	"regexp"
+	"strconv"
 	"strings"
 	"time"
 
@@ -32,9 +34,10 @@ func generateInviteCode() string {
 	return string(code)
 }
 
-// generateUid 生成用户UID
+// generateUid 生成用户UID(时间戳+随机数,防止并发碰撞)
 func generateUid() string {
-	return fmt.Sprintf("DT%d", time.Now().UnixNano()/1000000)
+	n, _ := rand.Int(rand.Reader, big.NewInt(9999))
+	return fmt.Sprintf("DT%d%04d", time.Now().UnixNano()/1000000, n.Int64())
 }
 
 // sendSmsBao 调用短信宝发送短信
@@ -86,7 +89,35 @@ func sendSmsBao(user, pass, sign, phone, code string) error {
 	return nil
 }
 
-// SendSmsCode 发送短信验证码
+// sendEmail 发送邮件验证码
+func sendEmail(host string, port int, user, pass, fromName, toEmail, code string) error {
+	from := user
+	subject := "Verification Code"
+	body := fmt.Sprintf(`
+		<div style="padding:20px;font-family:Arial,sans-serif;">
+			<h2 style="color:#ffc300;">Vitiens</h2>
+			<p>Your verification code is:</p>
+			<h1 style="color:#ffc300;font-size:36px;letter-spacing:8px;">%s</h1>
+			<p>This code will expire in 5 minutes.</p>
+			<p style="color:#999;font-size:12px;">If you did not request this code, please ignore this email.</p>
+		</div>
+	`, code)
+
+	// 构建邮件内容(固定header顺序)
+	msg := fmt.Sprintf("From: %s <%s>\r\n", fromName, from)
+	msg += fmt.Sprintf("To: %s\r\n", toEmail)
+	msg += fmt.Sprintf("Subject: %s\r\n", subject)
+	msg += "MIME-Version: 1.0\r\n"
+	msg += "Content-Type: text/html; charset=UTF-8\r\n"
+	msg += "\r\n" + body
+
+	addr := fmt.Sprintf("%s:%d", host, port)
+	auth := smtp.PlainAuth("", user, pass, host)
+
+	return smtp.SendMail(addr, auth, from, []string{toEmail}, []byte(msg))
+}
+
+// SendSmsCode 发送短信/邮件验证码
 func (s *Server) SendSmsCode(c *gin.Context) {
 	ctx := s.FromContext(c)
 	db := s.DB()
@@ -103,29 +134,39 @@ func (s *Server) SendSmsCode(c *gin.Context) {
 		return
 	}
 
-	// 目前只支持手机号
-	if req.Type != "phone" {
+	// 验证类型
+	if req.Type != "phone" && req.Type != "email" {
 		ctx.Fail("unsupported_type")
 		return
 	}
 
-	// 验证手机号格式
-	phoneRegex := regexp.MustCompile(`^\d{9,15}$`)
-	if !phoneRegex.MatchString(req.Account) {
-		ctx.Fail("invalid_phone")
-		return
-	}
+	var accountKey string // 用于缓存的key
 
-	// 完整手机号(带区号)
-	fullPhone := req.Account
-	if req.AreaCode != "" {
-		fullPhone = req.AreaCode + req.Account
+	if req.Type == "phone" {
+		// 验证手机号格式
+		phoneRegex := regexp.MustCompile(`^\d{9,15}$`)
+		if !phoneRegex.MatchString(req.Account) {
+			ctx.Fail("invalid_phone")
+			return
+		}
+		accountKey = req.Account
+		if req.AreaCode != "" {
+			accountKey = req.AreaCode + req.Account
+		}
+	} else {
+		// 验证邮箱格式
+		emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
+		if !emailRegex.MatchString(req.Account) {
+			ctx.Fail("invalid_email")
+			return
+		}
+		accountKey = req.Account
 	}
 
 	// 防止重复发送
-	cacheKey := fmt.Sprintf("sms:%s:%s", req.Scene, fullPhone)
+	cacheKey := fmt.Sprintf("code:%s:%s", req.Scene, accountKey)
 	if !ctx.RepeatFilter(cacheKey, 60*time.Second) {
-		ctx.Fail("sms_send_too_fast")
+		ctx.Fail("code_send_too_fast")
 		return
 	}
 
@@ -133,36 +174,71 @@ func (s *Server) SendSmsCode(c *gin.Context) {
 	n, _ := rand.Int(rand.Reader, big.NewInt(900000))
 	code := fmt.Sprintf("%06d", n.Int64()+100000)
 
-	// 获取短信宝配置
-	var smsUser, smsPass, smsSign string
-	var config entity.DtConfig
-	if err := db.Where("`key` = ?", entity.ConfigKeySmsUser).First(&config).Error; err == nil {
-		smsUser = config.Value
-	}
-	if err := db.Where("`key` = ?", entity.ConfigKeySmsPass).First(&config).Error; err == nil {
-		smsPass = config.Value
-	}
-	if err := db.Where("`key` = ?", entity.ConfigKeySmsSign).First(&config).Error; err == nil {
-		smsSign = config.Value
-	}
-
-	// 调用短信宝发送验证码
-	if smsUser != "" && smsPass != "" {
-		if err := sendSmsBao(smsUser, smsPass, smsSign, fullPhone, code); err != nil {
-			ctx.Fail("sms_send_failed")
-			return
+	if req.Type == "phone" {
+		// 发送短信验证码
+		var smsUser, smsPass, smsSign string
+		var config entity.DtConfig
+		if err := db.Where("`key` = ?", entity.ConfigKeySmsUser).First(&config).Error; err == nil {
+			smsUser = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmsPass).First(&config).Error; err == nil {
+			smsPass = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmsSign).First(&config).Error; err == nil {
+			smsSign = config.Value
+		}
+		if smsUser != "" && smsPass != "" {
+			if err := sendSmsBao(smsUser, smsPass, smsSign, accountKey, code); err != nil {
+				ctx.Fail("sms_send_failed")
+				return
+			}
+		} else {
+			fmt.Printf("[DEV] SMS Code for %s: %s\n", accountKey, code)
 		}
 	} else {
-		// 开发模式:打印验证码到日志
-		fmt.Printf("[DEV] SMS Code for %s: %s\n", fullPhone, code)
+		// 发送邮件验证码
+		var smtpHost, smtpPortStr, smtpUser, smtpPass, smtpFrom string
+		var config entity.DtConfig
+		if err := db.Where("`key` = ?", entity.ConfigKeySmtpHost).First(&config).Error; err == nil {
+			smtpHost = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmtpPort).First(&config).Error; err == nil {
+			smtpPortStr = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmtpUser).First(&config).Error; err == nil {
+			smtpUser = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmtpPass).First(&config).Error; err == nil {
+			smtpPass = config.Value
+		}
+		if err := db.Where("`key` = ?", entity.ConfigKeySmtpFrom).First(&config).Error; err == nil {
+			smtpFrom = config.Value
+		}
+		if smtpFrom == "" {
+			smtpFrom = "Vitiens"
+		}
+
+		if smtpHost != "" && smtpUser != "" && smtpPass != "" {
+			smtpPort, _ := strconv.Atoi(smtpPortStr)
+			if smtpPort == 0 {
+				smtpPort = 587
+			}
+			if err := sendEmail(smtpHost, smtpPort, smtpUser, smtpPass, smtpFrom, req.Account, code); err != nil {
+				fmt.Printf("[ERROR] Send email failed: %v\n", err)
+				ctx.Fail("email_send_failed")
+				return
+			}
+		} else {
+			fmt.Printf("[DEV] Email Code for %s: %s\n", req.Account, code)
+		}
 	}
 
 	// 存储验证码到Redis
-	codeKey := fmt.Sprintf("smscode:%s:%s", req.Scene, fullPhone)
-	redisclient.DefaultClient().Set(c, codeKey, code, 5*time.Minute)
+	codeKey2 := fmt.Sprintf("verifycode:%s:%s", req.Scene, accountKey)
+	redisclient.DefaultClient().Set(c, codeKey2, code, 5*time.Minute)
 
 	ctx.OK(gin.H{
-		"message": "sms_sent",
+		"message": "code_sent",
 	})
 }
 
@@ -185,31 +261,49 @@ func (s *Server) Register(c *gin.Context) {
 		return
 	}
 
-	// 目前只支持手机号注册
-	if req.Type != "phone" {
+	// 验证类型
+	if req.Type != "phone" && req.Type != "email" {
 		ctx.Fail("unsupported_type")
 		return
 	}
 
-	// 完整手机号(带区号)
-	fullPhone := req.Account
-	if req.AreaCode != "" {
-		fullPhone = req.AreaCode + req.Account
+	// 账号Key
+	accountKey := req.Account
+	if req.Type == "phone" && req.AreaCode != "" {
+		accountKey = req.AreaCode + req.Account
 	}
 
-	// 验证短信验证码
-	codeKey := fmt.Sprintf("smscode:register:%s", fullPhone)
+	// 检查验证码尝试次数(防暴力破解,5次锁定15分钟)
+	failKey := fmt.Sprintf("code_fail:register:%s", accountKey)
+	failCount, _ := redisclient.DefaultClient().Get(c, failKey).Int()
+	if failCount >= 5 {
+		ctx.Fail("too_many_attempts")
+		return
+	}
+
+	// 验证验证码
+	codeKey := fmt.Sprintf("verifycode:register:%s", accountKey)
 	storedCode, err := redisclient.DefaultClient().Get(c, codeKey).Result()
 	if err != nil || storedCode != req.Code {
+		redisclient.DefaultClient().Incr(c, failKey)
+		redisclient.DefaultClient().Expire(c, failKey, 15*time.Minute)
 		ctx.Fail("invalid_code")
 		return
 	}
+	redisclient.DefaultClient().Del(c, failKey)
 
-	// 检查手机号是否已注册
+	// 检查是否已注册
 	var existUser entity.DtUser
-	if err := db.Where("phone = ?", fullPhone).First(&existUser).Error; err == nil {
-		ctx.Fail("phone_registered")
-		return
+	if req.Type == "phone" {
+		if err := db.Where("phone = ?", accountKey).First(&existUser).Error; err == nil {
+			ctx.Fail("phone_registered")
+			return
+		}
+	} else {
+		if err := db.Where("email = ?", accountKey).First(&existUser).Error; err == nil {
+			ctx.Fail("email_registered")
+			return
+		}
 	}
 
 	// 查找邀请人
@@ -232,17 +326,29 @@ func (s *Server) Register(c *gin.Context) {
 	var defaultLevel entity.DtUserLevel
 	db.Where("is_default = ?", 1).First(&defaultLevel)
 
+	// 昵称
+	nickname := ctx.I18n("user_prefix")
+	if len(req.Account) >= 4 {
+		nickname += req.Account[len(req.Account)-4:]
+	} else {
+		nickname += req.Account
+	}
+
 	// 创建用户
 	user := &entity.DtUser{
 		Uid:        generateUid(),
-		Phone:      fullPhone,
 		Password:   string(hashedPassword),
-		Nickname:   "用户" + req.Account[len(req.Account)-4:],
+		Nickname:   nickname,
 		ParentId:   parentId,
 		LevelId:    defaultLevel.Id,
 		InviteCode: generateInviteCode(),
 		Status:     1,
 	}
+	if req.Type == "phone" {
+		user.Phone = accountKey
+	} else {
+		user.Email = accountKey
+	}
 
 	tx := db.Begin()
 
@@ -256,15 +362,15 @@ func (s *Server) Register(c *gin.Context) {
 	if parentId > 0 {
 		tx.Model(&entity.DtUser{}).Where("id = ?", parentId).
 			Updates(map[string]interface{}{
-				"direct_invite_count": db.Raw("direct_invite_count + 1"),
-				"team_count":          db.Raw("team_count + 1"),
+				"direct_invite_count": tx.Raw("direct_invite_count + 1"),
+				"team_count":          tx.Raw("team_count + 1"),
 			})
 
 		// 更新上级的团队人数(多级)
 		var parent entity.DtUser
 		if err := tx.Where("id = ?", parentId).First(&parent).Error; err == nil && parent.ParentId > 0 {
 			tx.Model(&entity.DtUser{}).Where("id = ?", parent.ParentId).
-				Update("team_count", db.Raw("team_count + 1"))
+				Update("team_count", tx.Raw("team_count + 1"))
 		}
 	}
 
@@ -288,6 +394,7 @@ func (s *Server) Register(c *gin.Context) {
 			"nickname": user.Nickname,
 			"avatar":   user.Avatar,
 			"phone":    user.Phone,
+			"email":    user.Email,
 		},
 	})
 }
@@ -309,17 +416,30 @@ func (s *Server) LoginByPassword(c *gin.Context) {
 		return
 	}
 
-	// 完整手机号(带区号)
-	fullPhone := req.Account
-	if req.AreaCode != "" {
-		fullPhone = req.AreaCode + req.Account
+	// 检查登录尝试次数(防暴力破解,5次锁定15分钟)
+	loginFailKey := fmt.Sprintf("login_fail:%s", req.Account)
+	loginFailCount, _ := redisclient.DefaultClient().Get(c, loginFailKey).Int()
+	if loginFailCount >= 5 {
+		ctx.Fail("too_many_attempts")
+		return
 	}
 
 	// 查找用户
 	var user entity.DtUser
-	if err := db.Where("phone = ?", fullPhone).First(&user).Error; err != nil {
-		ctx.Fail("user_not_found")
-		return
+	if req.Type == "email" {
+		if err := db.Where("email = ?", req.Account).First(&user).Error; err != nil {
+			ctx.Fail("user_not_found")
+			return
+		}
+	} else {
+		fullPhone := req.Account
+		if req.AreaCode != "" {
+			fullPhone = req.AreaCode + req.Account
+		}
+		if err := db.Where("phone = ?", fullPhone).First(&user).Error; err != nil {
+			ctx.Fail("user_not_found")
+			return
+		}
 	}
 
 	// 检查用户状态
@@ -330,13 +450,16 @@ func (s *Server) LoginByPassword(c *gin.Context) {
 
 	// 验证密码
 	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
+		redisclient.DefaultClient().Incr(c, loginFailKey)
+		redisclient.DefaultClient().Expire(c, loginFailKey, 15*time.Minute)
 		ctx.Fail("invalid_password")
 		return
 	}
+	redisclient.DefaultClient().Del(c, loginFailKey)
 
 	// 更新登录时间
 	db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
-		Update("last_login_at", time.Now().Unix())
+		Update("last_login_time", time.Now().Unix())
 
 	// 生成Token
 	token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
@@ -353,6 +476,7 @@ func (s *Server) LoginByPassword(c *gin.Context) {
 			"nickname": user.Nickname,
 			"avatar":   user.Avatar,
 			"phone":    user.Phone,
+			"email":    user.Email,
 		},
 	})
 }
@@ -379,13 +503,24 @@ func (s *Server) LoginBySms(c *gin.Context) {
 		fullPhone = req.AreaCode + req.Phone
 	}
 
+	// 检查验证码尝试次数(防暴力破解,5次锁定15分钟)
+	failKey := fmt.Sprintf("code_fail:login:%s", fullPhone)
+	failCount, _ := redisclient.DefaultClient().Get(c, failKey).Int()
+	if failCount >= 5 {
+		ctx.Fail("too_many_attempts")
+		return
+	}
+
 	// 验证短信验证码
-	codeKey := fmt.Sprintf("smscode:login:%s", fullPhone)
+	codeKey := fmt.Sprintf("verifycode:login:%s", fullPhone)
 	storedCode, err := redisclient.DefaultClient().Get(c, codeKey).Result()
 	if err != nil || storedCode != req.Code {
+		redisclient.DefaultClient().Incr(c, failKey)
+		redisclient.DefaultClient().Expire(c, failKey, 15*time.Minute)
 		ctx.Fail("invalid_code")
 		return
 	}
+	redisclient.DefaultClient().Del(c, failKey)
 
 	// 查找用户
 	var user entity.DtUser
@@ -402,7 +537,7 @@ func (s *Server) LoginBySms(c *gin.Context) {
 
 	// 更新登录时间
 	db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
-		Update("last_login_at", time.Now().Unix())
+		Update("last_login_time", time.Now().Unix())
 
 	// 删除验证码
 	redisclient.DefaultClient().Del(c, codeKey)
@@ -422,6 +557,7 @@ func (s *Server) LoginBySms(c *gin.Context) {
 			"nickname": user.Nickname,
 			"avatar":   user.Avatar,
 			"phone":    user.Phone,
+			"email":    user.Email,
 		},
 	})
 }
@@ -444,25 +580,43 @@ func (s *Server) ResetPassword(c *gin.Context) {
 		return
 	}
 
-	// 完整手机号(带区号)
-	fullPhone := req.Account
-	if req.AreaCode != "" {
-		fullPhone = req.AreaCode + req.Account
+	// 账号Key
+	accountKey := req.Account
+	if req.Type == "phone" && req.AreaCode != "" {
+		accountKey = req.AreaCode + req.Account
 	}
 
-	// 验证短信验证码
-	codeKey := fmt.Sprintf("smscode:reset:%s", fullPhone)
+	// 检查验证码尝试次数(防暴力破解,5次锁定15分钟)
+	failKey := fmt.Sprintf("code_fail:reset:%s", accountKey)
+	failCount, _ := redisclient.DefaultClient().Get(c, failKey).Int()
+	if failCount >= 5 {
+		ctx.Fail("too_many_attempts")
+		return
+	}
+
+	// 验证验证码
+	codeKey := fmt.Sprintf("verifycode:reset:%s", accountKey)
 	storedCode, err := redisclient.DefaultClient().Get(c, codeKey).Result()
 	if err != nil || storedCode != req.Code {
+		redisclient.DefaultClient().Incr(c, failKey)
+		redisclient.DefaultClient().Expire(c, failKey, 15*time.Minute)
 		ctx.Fail("invalid_code")
 		return
 	}
+	redisclient.DefaultClient().Del(c, failKey)
 
 	// 查找用户
 	var user entity.DtUser
-	if err := db.Where("phone = ?", fullPhone).First(&user).Error; err != nil {
-		ctx.Fail("user_not_found")
-		return
+	if req.Type == "email" {
+		if err := db.Where("email = ?", accountKey).First(&user).Error; err != nil {
+			ctx.Fail("user_not_found")
+			return
+		}
+	} else {
+		if err := db.Where("phone = ?", accountKey).First(&user).Error; err != nil {
+			ctx.Fail("user_not_found")
+			return
+		}
 	}
 
 	// 加密新密码
@@ -533,7 +687,7 @@ func (s *Server) OAuthLogin(c *gin.Context) {
 
 		// 更新登录时间
 		db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
-			Update("last_login_at", time.Now().Unix())
+			Update("last_login_time", time.Now().Unix())
 
 		token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
 		if err != nil {
@@ -569,15 +723,14 @@ func (s *Server) OAuthLogin(c *gin.Context) {
 
 	nickname := req.Nickname
 	if nickname == "" {
-		nickname = req.Provider + "用户"
+		nickname = req.Provider + ctx.I18n("user_prefix")
 	}
 
-	// 为OAuth用户生成唯一占位手机号
-	openIdPart := req.OpenId
-	if len(openIdPart) > 16 {
-		openIdPart = openIdPart[:16]
-	}
-	oauthPhone := fmt.Sprintf("oauth_%s_%s", req.Provider, openIdPart)
+	// 为OAuth用户生成唯一占位手机号(用MD5避免截断碰撞)
+	h := md5.New()
+	h.Write([]byte(req.OpenId))
+	openIdHash := hex.EncodeToString(h.Sum(nil))[:16]
+	oauthPhone := fmt.Sprintf("oauth_%s_%s", req.Provider, openIdHash)
 
 	user := &entity.DtUser{
 		Uid:        generateUid(),
@@ -619,8 +772,8 @@ func (s *Server) OAuthLogin(c *gin.Context) {
 	if parentId > 0 {
 		tx.Model(&entity.DtUser{}).Where("id = ?", parentId).
 			Updates(map[string]interface{}{
-				"direct_invite_count": db.Raw("direct_invite_count + 1"),
-				"team_count":          db.Raw("team_count + 1"),
+				"direct_invite_count": tx.Raw("direct_invite_count + 1"),
+				"team_count":          tx.Raw("team_count + 1"),
 			})
 	}
 
@@ -641,6 +794,7 @@ func (s *Server) OAuthLogin(c *gin.Context) {
 			"nickname": user.Nickname,
 			"avatar":   user.Avatar,
 			"phone":    user.Phone,
+			"email":    user.Email,
 		},
 		"isNew": true,
 	})

+ 20 - 10
apis/daytask/home.go

@@ -30,23 +30,25 @@ func (s *Server) HomeIndex(c *gin.Context) {
 
 	// 获取推荐任务
 	recommendTasks := make([]*entity.DtTask, 0)
-	query := db.Model(&entity.DtTask{}).Where("status = ?", 1)
 	// 尝试使用 is_recommend 字段,如果字段不存在则降级处理
-	if err := query.Where("is_recommend = ?", 1).Order("sort DESC, id DESC").Limit(10).Find(&recommendTasks).Error; err != nil {
+	if err := db.Model(&entity.DtTask{}).Where("status = ? AND is_recommend = ?", 1, 1).
+		Order("sort DESC, id DESC").Limit(10).Find(&recommendTasks).Error; err != nil {
 		if isColumnNotExistError(err) {
 			// 字段不存在,使用降级查询(不筛选推荐)
-			query.Order("sort DESC, id DESC").Limit(10).Find(&recommendTasks)
+			db.Model(&entity.DtTask{}).Where("status = ?", 1).
+				Order("sort DESC, id DESC").Limit(10).Find(&recommendTasks)
 		}
 	}
 
 	// 获取普通任务
 	normalTasks := make([]*entity.DtTask, 0)
-	query = db.Model(&entity.DtTask{}).Where("status = ? AND remain_count > 0", 1)
 	// 尝试使用 is_top 字段,如果字段不存在则降级处理
-	if err := query.Order("is_top DESC, created_at DESC").Limit(20).Find(&normalTasks).Error; err != nil {
+	if err := db.Model(&entity.DtTask{}).Where("status = ? AND remain_count > 0", 1).
+		Order("is_top DESC, created_at DESC").Limit(20).Find(&normalTasks).Error; err != nil {
 		if isColumnNotExistError(err) {
 			// 字段不存在,使用降级查询(不按置顶排序)
-			query.Order("created_at DESC").Limit(20).Find(&normalTasks)
+			db.Model(&entity.DtTask{}).Where("status = ? AND remain_count > 0", 1).
+				Order("created_at DESC").Limit(20).Find(&normalTasks)
 		}
 	}
 
@@ -106,13 +108,14 @@ func (s *Server) HomeHall(c *gin.Context) {
 	paging.Computer()
 
 	// 尝试使用 is_top 和 is_recommend 字段,如果字段不存在则降级处理
+	fallbackQuery := db.Model(&entity.DtTask{}).Where("status = ? AND remain_count > 0", 1)
 	if err := query.Order("is_top DESC, is_recommend DESC, created_at DESC").
 		Offset(int(paging.Start)).
 		Limit(int(paging.Size)).
 		Find(&tasks).Error; err != nil {
 		if isColumnNotExistError(err) {
-			// 字段不存在,使用降级查询(不按置顶和推荐排序)
-			query.Order("created_at DESC").
+			// 字段不存在,使用全新查询(不按置顶和推荐排序)
+			fallbackQuery.Order("created_at DESC").
 				Offset(int(paging.Start)).
 				Limit(int(paging.Size)).
 				Find(&tasks)
@@ -169,13 +172,20 @@ func (s *Server) TaskList(c *gin.Context) {
 
 	tasks := make([]*entity.DtTask, 0)
 	// 尝试使用 is_top 和 is_recommend 字段,如果字段不存在则降级处理
+	fallbackQuery := db.Model(&entity.DtTask{}).Where("status = ? AND remain_count > 0", 1)
+	if categoryId > 0 {
+		fallbackQuery = fallbackQuery.Where("category_id = ?", categoryId)
+	}
+	if keyword != "" {
+		fallbackQuery = fallbackQuery.Where("title LIKE ?", "%"+keyword+"%")
+	}
 	if err := query.Order("is_top DESC, is_recommend DESC, created_at DESC").
 		Offset(int(paging.Start)).
 		Limit(int(paging.Size)).
 		Find(&tasks).Error; err != nil {
 		if isColumnNotExistError(err) {
-			// 字段不存在,使用降级查询(不按置顶和推荐排序)
-			query.Order("created_at DESC").
+			// 字段不存在,使用全新查询(不按置顶和推荐排序)
+			fallbackQuery.Order("created_at DESC").
 				Offset(int(paging.Start)).
 				Limit(int(paging.Size)).
 				Find(&tasks)

+ 82 - 28
apis/daytask/notice.go

@@ -45,6 +45,27 @@ func (s *Server) NoticeList(c *gin.Context) {
 		Limit(int(paging.Size)).
 		Scan(&notices)
 
+	// 对于全局通知(user_id=0),批量查询已读状态(避免N+1)
+	unreadIds := make([]int64, 0)
+	for _, n := range notices {
+		if n.IsRead == 0 {
+			unreadIds = append(unreadIds, n.Id)
+		}
+	}
+	if len(unreadIds) > 0 {
+		var readRecords []entity.DtNoticeRead
+		db.Where("user_id = ? AND notice_id IN ?", userId, unreadIds).Find(&readRecords)
+		readMap := make(map[int64]bool)
+		for _, r := range readRecords {
+			readMap[r.NoticeId] = true
+		}
+		for _, n := range notices {
+			if n.IsRead == 0 && readMap[n.Id] {
+				n.IsRead = 1
+			}
+		}
+	}
+
 	ctx.OK(gin.H{
 		"list":   notices,
 		"paging": paging,
@@ -66,9 +87,27 @@ func (s *Server) NoticeRead(c *gin.Context) {
 		return
 	}
 
-	db.Model(&entity.DtNotice{}).
-		Where("id = ? AND (user_id = ? OR user_id = 0)", req.Id, userId).
-		Update("is_read", 1)
+	// 查询通知
+	var notice entity.DtNotice
+	if err := db.Where("id = ?", req.Id).First(&notice).Error; err != nil {
+		ctx.Fail("notice_not_found")
+		return
+	}
+
+	if notice.UserId == 0 {
+		// 全局通知:写入关联表记录已读
+		readRecord := &entity.DtNoticeRead{
+			UserId:   userId,
+			NoticeId: req.Id,
+		}
+		db.Where("user_id = ? AND notice_id = ?", userId, req.Id).
+			FirstOrCreate(readRecord)
+	} else {
+		// 个人通知:直接更新
+		db.Model(&entity.DtNotice{}).
+			Where("id = ? AND user_id = ?", req.Id, userId).
+			Update("is_read", 1)
+	}
 
 	ctx.OK(gin.H{})
 }
@@ -81,14 +120,32 @@ func (s *Server) NoticeReadAll(c *gin.Context) {
 
 	noticeType := ctx.QueryString("type", "")
 
-	query := db.Model(&entity.DtNotice{}).
-		Where("(user_id = ? OR user_id = 0) AND is_read = ?", userId, 0)
+	// 1. 更新个人通知
+	personalQuery := db.Model(&entity.DtNotice{}).
+		Where("user_id = ? AND is_read = ?", userId, 0)
+	if noticeType != "" {
+		personalQuery = personalQuery.Where("type = ?", noticeType)
+	}
+	personalQuery.Update("is_read", 1)
 
+	// 2. 全局通知:查出未读的全局通知,批量写入关联表
+	globalQuery := db.Model(&entity.DtNotice{}).
+		Where("user_id = 0 AND status = 1")
 	if noticeType != "" {
-		query = query.Where("type = ?", noticeType)
+		globalQuery = globalQuery.Where("type = ?", noticeType)
 	}
 
-	query.Update("is_read", 1)
+	var globalNotices []entity.DtNotice
+	globalQuery.Select("id").Find(&globalNotices)
+
+	for _, n := range globalNotices {
+		readRecord := &entity.DtNoticeRead{
+			UserId:   userId,
+			NoticeId: n.Id,
+		}
+		db.Where("user_id = ? AND notice_id = ?", userId, n.Id).
+			FirstOrCreate(readRecord)
+	}
 
 	ctx.OK(gin.H{})
 }
@@ -99,29 +156,26 @@ func (s *Server) NoticeUnread(c *gin.Context) {
 	db := s.DB()
 	userId := ctx.UserId()
 
-	// 统计各类型未读数量
-	type UnreadCount struct {
-		Type  string `json:"type"`
-		Count int64  `json:"count"`
-	}
+	// 个人未读通知数
+	var personalCount int64
+	db.Model(&entity.DtNotice{}).
+		Where("user_id = ? AND is_read = ? AND status = ?", userId, 0, 1).
+		Count(&personalCount)
 
-	var counts []UnreadCount
+	// 全局未读通知数(排除已在关联表中标记已读的)
+	var globalCount int64
 	db.Model(&entity.DtNotice{}).
-		Select("type, COUNT(*) as count").
-		Where("(user_id = ? OR user_id = 0) AND is_read = ? AND status = ?", userId, 0, 1).
-		Group("type").
-		Scan(&counts)
-
-	// 构建返回结果
-	result := make(map[string]int64)
-	var total int64 = 0
-	for _, c := range counts {
-		result[c.Type] = c.Count
-		total += c.Count
-	}
-	result["total"] = total
+		Where("user_id = 0 AND status = 1").
+		Where("id NOT IN (SELECT notice_id FROM dt_notice_read WHERE user_id = ?)", userId).
+		Count(&globalCount)
+
+	total := personalCount + globalCount
 
-	ctx.OK(result)
+	ctx.OK(gin.H{
+		"personal": personalCount,
+		"global":   globalCount,
+		"total":    total,
+	})
 }
 
 // NoticeDelete 删除消息
@@ -139,7 +193,7 @@ func (s *Server) NoticeDelete(c *gin.Context) {
 		return
 	}
 
-	// 只能删除自己的消息
+	// 只能删除自己的消息(全局通知不可删除)
 	db.Where("id = ? AND user_id = ?", req.Id, userId).Delete(&entity.DtNotice{})
 
 	ctx.OK(gin.H{})

+ 41 - 29
apis/daytask/rank.go

@@ -52,21 +52,27 @@ func (s *Server) RankList(c *gin.Context) {
 		// 任务完成数量排行
 		db.Raw(`
 			SELECT
-				@rank := @rank + 1 as `+"`rank`"+`,
-				u.id as user_id,
-				u.uid,
-				u.nickname,
-				u.avatar,
-				COUNT(t.id) as count
-			FROM dt_user u
-			LEFT JOIN dt_user_task t ON u.id = t.user_id
-				AND t.status = ?
-				AND `+timeCondition+`
-			CROSS JOIN (SELECT @rank := 0) r
-			WHERE u.status = 1
-			GROUP BY u.id
-			HAVING count > 0
-			ORDER BY count DESC
+				ROW_NUMBER() OVER(ORDER BY cnt DESC) as `+"`rank`"+`,
+				user_id,
+				uid,
+				nickname,
+				avatar,
+				cnt as count
+			FROM (
+				SELECT
+					u.id as user_id,
+					u.uid,
+					u.nickname,
+					u.avatar,
+					COUNT(t.id) as cnt
+				FROM dt_user u
+				INNER JOIN dt_user_task t ON u.id = t.user_id
+					AND t.status = ?
+					AND `+timeCondition+`
+				WHERE u.status = 1
+				GROUP BY u.id, u.uid, u.nickname, u.avatar
+				HAVING cnt > 0
+			) ranked
 			LIMIT ?
 		`, entity.UserTaskStatusCompleted, limit).Scan(&ranks)
 
@@ -74,20 +80,26 @@ func (s *Server) RankList(c *gin.Context) {
 		// 邀请人数排行
 		db.Raw(`
 			SELECT
-				@rank := @rank + 1 as `+"`rank`"+`,
-				u.id as user_id,
-				u.uid,
-				u.nickname,
-				u.avatar,
-				COUNT(t.id) as count
-			FROM dt_user u
-			LEFT JOIN dt_user t ON u.id = t.parent_id
-				AND `+timeCondition+`
-			CROSS JOIN (SELECT @rank := 0) r
-			WHERE u.status = 1
-			GROUP BY u.id
-			HAVING count > 0
-			ORDER BY count DESC
+				ROW_NUMBER() OVER(ORDER BY cnt DESC) as `+"`rank`"+`,
+				user_id,
+				uid,
+				nickname,
+				avatar,
+				cnt as count
+			FROM (
+				SELECT
+					u.id as user_id,
+					u.uid,
+					u.nickname,
+					u.avatar,
+					COUNT(t.id) as cnt
+				FROM dt_user u
+				INNER JOIN dt_user t ON u.id = t.parent_id
+					AND `+timeCondition+`
+				WHERE u.status = 1
+				GROUP BY u.id, u.uid, u.nickname, u.avatar
+				HAVING cnt > 0
+			) ranked
 			LIMIT ?
 		`, limit).Scan(&ranks)
 

+ 13 - 8
apis/daytask/sign.go

@@ -115,11 +115,11 @@ func (s *Server) SignDo(c *gin.Context) {
 		return
 	}
 
-	// 更新用户签到信息和余额
+	// 更新用户签到信息和余额(原子操作)
 	updates := map[string]interface{}{
 		"continuous_sign_days": continuousDays,
-		"total_sign_days":      user.TotalSignDays + 1,
-		"balance":              user.Balance + reward,
+		"total_sign_days":      db.Raw("total_sign_days + 1"),
+		"balance":              db.Raw("balance + ?", reward),
 	}
 	if err := tx.Model(&entity.DtUser{}).Where("id = ?", userId).Updates(updates).Error; err != nil {
 		tx.Rollback()
@@ -127,13 +127,18 @@ func (s *Server) SignDo(c *gin.Context) {
 		return
 	}
 
+	// 重新查询更新后的余额
+	var updatedUser entity.DtUser
+	tx.Where("id = ?", userId).First(&updatedUser)
+
 	// 记录余额变动
 	balanceLog := &entity.DtBalanceLog{
-		UserId:       userId,
-		Type:         entity.BalanceLogTypeSignReward,
-		Amount:       reward,
-		AfterBalance: user.Balance + reward,
-		Remark:       "签到奖励",
+		UserId:        userId,
+		Type:          entity.BalanceLogTypeSignReward,
+		Amount:        reward,
+		BeforeBalance: updatedUser.Balance - reward,
+		AfterBalance:  updatedUser.Balance,
+		Remark:        ctx.I18n("sign_reward"),
 	}
 	if err := tx.Create(balanceLog).Error; err != nil {
 		tx.Rollback()

+ 45 - 13
apis/daytask/task.go

@@ -86,6 +86,19 @@ func (s *Server) TaskApply(c *gin.Context) {
 		return
 	}
 
+	// 事务:创建任务记录 + 扣减剩余数量
+	tx := db.Begin()
+
+	// 先原子扣减剩余数量,防止超领
+	result := tx.Model(&entity.DtTask{}).
+		Where("id = ? AND remain_count > 0", req.TaskId).
+		UpdateColumn("remain_count", tx.Raw("remain_count - 1"))
+	if result.Error != nil || result.RowsAffected == 0 {
+		tx.Rollback()
+		ctx.Fail("task_sold_out")
+		return
+	}
+
 	// 创建任务记录
 	userTask := &entity.DtUserTask{
 		UserId:       userId,
@@ -95,15 +108,13 @@ func (s *Server) TaskApply(c *gin.Context) {
 		RewardAmount: task.RewardAmount,
 		Status:       entity.UserTaskStatusPending,
 	}
-	if err := db.Create(userTask).Error; err != nil {
+	if err := tx.Create(userTask).Error; err != nil {
+		tx.Rollback()
 		ctx.Fail("claim_failed")
 		return
 	}
 
-	// 更新任务剩余数量(原子操作)
-	db.Model(&entity.DtTask{}).
-		Where("id = ? AND remain_count > 0", req.TaskId).
-		UpdateColumn("remain_count", db.Raw("remain_count - 1"))
+	tx.Commit()
 
 	ctx.OK(userTask)
 }
@@ -191,15 +202,18 @@ func (s *Server) TaskAbandon(c *gin.Context) {
 		return
 	}
 
-	// 更新状态
-	db.Model(&entity.DtUserTask{}).
+	// 事务:更新状态 + 恢复数量
+	tx := db.Begin()
+
+	tx.Model(&entity.DtUserTask{}).
 		Where("id = ?", req.UserTaskId).
 		Update("status", entity.UserTaskStatusAbandoned)
 
-	// 恢复任务数量
-	db.Model(&entity.DtTask{}).
+	tx.Model(&entity.DtTask{}).
 		Where("id = ?", userTask.TaskId).
-		Update("remain_count", db.Raw("remain_count + 1"))
+		UpdateColumn("remain_count", tx.Raw("remain_count + 1"))
+
+	tx.Commit()
 
 	ctx.OK(nil)
 }
@@ -231,16 +245,34 @@ func (s *Server) TaskMy(c *gin.Context) {
 		Limit(int(paging.Size)).
 		Find(&userTasks)
 
-	// 获取关联的任务信息
+	// 批量获取关联的任务信息(避免N+1查询)
 	type TaskWithInfo struct {
 		*entity.DtUserTask
 		Task *entity.DtTask `json:"task"`
 	}
 
+	// 收集所有 taskId
+	taskIds := make([]int64, 0, len(userTasks))
+	for _, ut := range userTasks {
+		taskIds = append(taskIds, ut.TaskId)
+	}
+
+	// 一次性查询所有任务
+	taskMap := make(map[int64]*entity.DtTask)
+	if len(taskIds) > 0 {
+		tasks := make([]*entity.DtTask, 0)
+		db.Where("id IN ?", taskIds).Find(&tasks)
+		for _, t := range tasks {
+			taskMap[t.Id] = t
+		}
+	}
+
 	result := make([]*TaskWithInfo, 0)
 	for _, ut := range userTasks {
-		task := &entity.DtTask{}
-		db.Where("id = ?", ut.TaskId).First(task)
+		task := taskMap[ut.TaskId]
+		if task == nil {
+			task = &entity.DtTask{}
+		}
 		result = append(result, &TaskWithInfo{
 			DtUserTask: ut,
 			Task:       task,

+ 1 - 1
apis/daytask/team.go

@@ -77,7 +77,7 @@ func (s *Server) TeamMembers(c *gin.Context) {
 	}
 
 	members := make([]*MemberInfo, 0)
-	query.Select("id, uid, nickname, avatar, LEFT(phone, 3) || '****' || RIGHT(phone, 4) as phone, created_at").
+	query.Select("id, uid, nickname, avatar, CONCAT(LEFT(phone, 3), '****', RIGHT(phone, 4)) as phone, created_at").
 		Order("created_at DESC").
 		Offset(int(paging.Start)).
 		Limit(int(paging.Size)).

+ 13 - 8
apis/daytask/user.go

@@ -3,6 +3,7 @@ package daytask
 import (
 	"app/commons/model/entity"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/crypto/bcrypt"
 )
 
 // UserInfo 用户信息
@@ -96,15 +97,19 @@ func (s *Server) UserPassword(c *gin.Context) {
 		return
 	}
 
-	// TODO: 验证旧密码 (需要加密比对)
-	// if !verifyPassword(req.OldPassword, user.Password) {
-	//     ctx.Fail("old_password_error")
-	//     return
-	// }
+	// 验证旧密码
+	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.OldPassword)); err != nil {
+		ctx.Fail("old_password_error")
+		return
+	}
 
-	// TODO: 加密新密码
-	// newPasswordHash := hashPassword(req.NewPassword)
-	db.Model(&entity.DtUser{}).Where("id = ?", userId).Update("password", req.NewPassword)
+	// 加密新密码
+	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
+	if err != nil {
+		ctx.Fail("system_error")
+		return
+	}
+	db.Model(&entity.DtUser{}).Where("id = ?", userId).Update("password", string(hashedPassword))
 
 	ctx.OK(nil)
 }

+ 16 - 6
apis/daytask/wallet.go

@@ -138,14 +138,24 @@ func (s *Server) WalletWithdraw(c *gin.Context) {
 	// 开始事务
 	tx := db.Begin()
 
-	// 扣除余额
-	if err := tx.Model(&entity.DtUser{}).
+	// 扣除余额(原子操作 + 检查影响行数)
+	result := tx.Model(&entity.DtUser{}).
 		Where("id = ? AND balance >= ?", userId, req.Amount).
-		Update("balance", db.Raw("balance - ?", req.Amount)).Error; err != nil {
+		Update("balance", tx.Raw("balance - ?", req.Amount))
+	if result.Error != nil {
 		tx.Rollback()
 		ctx.Fail("deduct_balance_failed")
 		return
 	}
+	if result.RowsAffected == 0 {
+		tx.Rollback()
+		ctx.Fail("balance_not_enough")
+		return
+	}
+
+	// 重新查询扣款后的余额
+	var updatedUser entity.DtUser
+	tx.Where("id = ?", userId).First(&updatedUser)
 
 	// 创建提现订单
 	order := &entity.DtWithdrawOrder{
@@ -170,10 +180,10 @@ func (s *Server) WalletWithdraw(c *gin.Context) {
 		UserId:        userId,
 		Type:          entity.BalanceLogTypeWithdraw,
 		Amount:        -req.Amount,
-		BeforeBalance: user.Balance,
-		AfterBalance:  user.Balance - req.Amount,
+		BeforeBalance: updatedUser.Balance + req.Amount,
+		AfterBalance:  updatedUser.Balance,
 		RelatedId:     order.Id,
-		Remark:        fmt.Sprintf("提现申请 %s", orderNo),
+		Remark:        fmt.Sprintf("%s %s", ctx.I18n("withdraw_apply"), orderNo),
 	}
 	tx.Create(balanceLog)
 

+ 1 - 0
cmds/migrate.go

@@ -58,6 +58,7 @@ var allTables = []MigrateTable{
 	entity.NewDtWithdrawOrder(),
 	entity.NewDtUserSign(),
 	entity.NewDtNotice(),
+	&entity.DtNoticeRead{},
 	entity.NewDtBanner(),
 	entity.NewDtMaterial(),
 	entity.NewDtMaterialCategory(),

+ 5 - 0
commons/model/entity/dt_config.go

@@ -41,4 +41,9 @@ const (
 	ConfigKeyZaloAppId         = "zalo_app_id"        // Zalo App ID
 	ConfigKeyZaloSecret        = "zalo_secret"        // Zalo Secret Key
 	ConfigKeyTelegramBotName   = "telegram_bot_name"  // Telegram Bot 用户名
+	ConfigKeySmtpHost          = "smtp_host"          // SMTP服务器地址
+	ConfigKeySmtpPort          = "smtp_port"          // SMTP端口
+	ConfigKeySmtpUser          = "smtp_user"          // SMTP用户名(发件邮箱)
+	ConfigKeySmtpPass          = "smtp_pass"          // SMTP密码/授权码
+	ConfigKeySmtpFrom          = "smtp_from"          // 发件人名称
 )

+ 16 - 0
commons/model/entity/dt_notice_read.go

@@ -0,0 +1,16 @@
+package entity
+
+// DtNoticeRead 用户通知已读记录表(解决全局通知 user_id=0 的已读状态问题)
+type DtNoticeRead struct {
+	MysqlBaseModel
+	UserId   int64 `json:"userId" gorm:"uniqueIndex:uk_user_notice;comment:用户ID"`
+	NoticeId int64 `json:"noticeId" gorm:"uniqueIndex:uk_user_notice;comment:通知ID"`
+}
+
+func (*DtNoticeRead) TableName() string {
+	return "dt_notice_read"
+}
+
+func (*DtNoticeRead) Comment() string {
+	return "用户通知已读记录表"
+}

+ 1 - 0
commons/model/entity/dt_user.go

@@ -5,6 +5,7 @@ type DtUser struct {
 	MysqlFullModel
 	Uid              string  `json:"uid" gorm:"type:varchar(32);uniqueIndex:uk_uid;comment:用户UID"`
 	Phone            string  `json:"phone" gorm:"type:varchar(32);uniqueIndex:uk_phone;comment:手机号"`
+	Email            string  `json:"email" gorm:"type:varchar(128);index:idx_email;comment:邮箱"`
 	Password         string  `json:"-" gorm:"type:varchar(128);comment:登录密码"`
 	PayPassword      string  `json:"-" gorm:"type:varchar(128);comment:支付密码"`
 	Nickname         string  `json:"nickname" gorm:"type:varchar(64);comment:昵称"`