package daytask
import (
"app/apis/middleware"
"app/commons/core/redisclient"
"app/commons/model/entity"
"crypto/md5"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"net/smtp"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
)
// generateInviteCode returns a random 8-character invite code made of
// upper-case ASCII letters and decimal digits, drawn with crypto/rand.
// Uniqueness is not checked here.
func generateInviteCode() string {
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	const length = 8
	upper := big.NewInt(int64(len(alphabet)))
	out := make([]byte, 0, length)
	for len(out) < length {
		// crypto/rand is not expected to fail; the error is ignored as before.
		idx, _ := rand.Int(rand.Reader, upper)
		out = append(out, alphabet[idx.Int64()])
	}
	return string(out)
}
// generateUid builds a user UID: "DT", the current Unix time in milliseconds,
// then a zero-padded 4-digit random number in [0, 9998]. The random suffix
// makes two UIDs minted in the same millisecond unlikely (not impossible) to
// collide.
func generateUid() string {
	// crypto/rand is not expected to fail; the error is ignored as before.
	nonce, _ := rand.Int(rand.Reader, big.NewInt(9999))
	millis := time.Now().UnixNano() / int64(time.Millisecond)
	return "DT" + strconv.FormatInt(millis, 10) + fmt.Sprintf("%04d", nonce.Int64())
}
// sendSmsBao sends a verification-code SMS through the SMSBao HTTP API.
//
// phone is either a bare domestic number or the international area code
// concatenated with the subscriber number. Mainland-China numbers (an "86"
// prefix, or a bare 11-digit number starting with "1") go to the domestic
// endpoint, with a leading "86" stripped. Everything else goes to the
// international endpoint unchanged.
//
// pass is the plain account password; the API expects its MD5 hex digest.
// sign is the SMS signature placed inside 【】. The call returns nil only when
// the gateway replies with the body "0".
func sendSmsBao(user, pass, sign, phone, code string) error {
	const (
		domesticAPI      = "http://api.smsbao.com/sms"
		internationalAPI = "http://api.smsbao.com/wsms"
	)
	apiUrl := internationalAPI
	switch {
	case strings.HasPrefix(phone, "86"):
		apiUrl = domesticAPI
		phone = phone[2:]
	case len(phone) == 11 && strings.HasPrefix(phone, "1"):
		// Bare mainland mobile number (they all start with 1). Other 11-digit
		// strings, e.g. "84"+9-digit Vietnamese or "852"+8-digit HK numbers,
		// must stay on the international endpoint.
		apiUrl = domesticAPI
	}

	passSum := md5.Sum([]byte(pass))
	params := url.Values{}
	params.Set("u", user)
	params.Set("p", hex.EncodeToString(passSum[:]))
	params.Set("m", phone)
	params.Set("c", fmt.Sprintf("【%s】您的验证码是%s,在5分钟内有效。", sign, code))

	// Bounded timeout: the default client would let a hung gateway block the
	// calling request handler indefinitely.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(apiUrl + "?" + params.Encode())
	if err != nil {
		return fmt.Errorf("sms request failed: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("sms read response failed: %w", err)
	}
	// SMSBao answers with a numeric status: "0" means success.
	if result := strings.TrimSpace(string(body)); result != "0" {
		return fmt.Errorf("smsbao error: %s", result)
	}
	return nil
}
// sendEmail delivers a verification code by SMTP.
//
// It connects to host:port and authenticates with PLAIN auth as user/pass.
// user is also the envelope sender and the From address, shown under the
// display name fromName. smtp.SendMail upgrades the connection with STARTTLS
// when the server advertises it. Servers that require implicit TLS (commonly
// port 465) are not supported by this function.
func sendEmail(host string, port int, user, pass, fromName, toEmail, code string) error {
	from := user
	subject := "Verification Code"
	body := fmt.Sprintf(`
Vitiens
Your verification code is:
%s
This code will expire in 5 minutes.
If you did not request this code, please ignore this email.
`, code)

	// Headers are written in a fixed order. The body is plain text that relies
	// on line breaks, so it is declared text/plain; as text/html every line
	// ran together. Date is a required RFC 5322 header, and messages without
	// it are commonly penalised by spam filters.
	var msg strings.Builder
	msg.WriteString(fmt.Sprintf("From: %s <%s>\r\n", fromName, from))
	msg.WriteString(fmt.Sprintf("To: %s\r\n", toEmail))
	msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
	msg.WriteString("Date: " + time.Now().Format(time.RFC1123Z) + "\r\n")
	msg.WriteString("MIME-Version: 1.0\r\n")
	msg.WriteString("Content-Type: text/plain; charset=UTF-8\r\n")
	msg.WriteString("\r\n" + body)

	addr := fmt.Sprintf("%s:%d", host, port)
	auth := smtp.PlainAuth("", user, pass, host)
	return smtp.SendMail(addr, auth, from, []string{toEmail}, []byte(msg.String()))
}
// sendCodePhonePattern accepts 9-15 digits: the phone number without the
// international area code, which arrives in a separate field.
var sendCodePhonePattern = regexp.MustCompile(`^\d{9,15}$`)

// sendCodeEmailPattern is a pragmatic shape check for e-mail addresses.
var sendCodeEmailPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)

// SendSmsCode generates a six-digit verification code and delivers it by SMS
// or e-mail.
//
// The request JSON has these fields:
//   - type: "phone" or "email".
//   - account: the phone number without area code, or the e-mail address.
//   - areaCode: optional, phone only; it is prepended to the account.
//   - scene: one of register/login/reset/bind.
//
// After delivery the code is stored in Redis under
// "verifycode:<scene>:<accountKey>" for 5 minutes.
// If the SMS or SMTP gateway is not configured, the code is printed to stdout
// instead of being delivered (development fallback) and the call still
// succeeds.
func (s *Server) SendSmsCode(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()
	type Req struct {
		Type     string `json:"type" binding:"required"`    // phone/email
		Account  string `json:"account" binding:"required"` // phone number or e-mail
		AreaCode string `json:"areaCode"`                   // international dialing code, e.g. 84, 86
		Scene    string `json:"scene" binding:"required"`   // register/login/reset/bind
	}
	var req Req
	if err := c.ShouldBindJSON(&req); err != nil {
		ctx.Fail("invalid_params")
		return
	}

	// accountKey names the recipient in Redis keys. For phones it is the area
	// code plus the number; for e-mail it is the address itself.
	var accountKey string
	switch req.Type {
	case "phone":
		if !sendCodePhonePattern.MatchString(req.Account) {
			ctx.Fail("invalid_phone")
			return
		}
		accountKey = req.AreaCode + req.Account
	case "email":
		if !sendCodeEmailPattern.MatchString(req.Account) {
			ctx.Fail("invalid_email")
			return
		}
		accountKey = req.Account
	default:
		ctx.Fail("unsupported_type")
		return
	}

	// Rate-limit per recipient, not per (scene, recipient). scene is supplied
	// by the client, so including it in the key let a caller skip the 60s
	// resend limit by varying it, and flood someone's phone or inbox.
	if !ctx.RepeatFilter(fmt.Sprintf("code:%s", accountKey), 60*time.Second) {
		ctx.Fail("code_send_too_fast")
		return
	}

	// Uniform in [100000, 999999].
	n, _ := rand.Int(rand.Reader, big.NewInt(900000))
	code := fmt.Sprintf("%06d", n.Int64()+100000)

	sendFailed := "email_send_failed"
	if req.Type == "phone" {
		sendFailed = "sms_send_failed"
		// Use one struct per lookup. GORM adds a non-zero primary key of the
		// destination struct to First's WHERE clause. Reusing a single struct
		// (the pitfall the smtp branch below also avoids) therefore made the
		// password and signature lookups match nothing once the user lookup
		// had succeeded, and every SMS silently fell through to the dev branch.
		var userCfg, passCfg, signCfg entity.DtConfig
		var smsUser, smsPass, smsSign string
		if err := db.Where("`key` = ?", entity.ConfigKeySmsUser).First(&userCfg).Error; err == nil {
			smsUser = userCfg.Value
		}
		if err := db.Where("`key` = ?", entity.ConfigKeySmsPass).First(&passCfg).Error; err == nil {
			smsPass = passCfg.Value
		}
		if err := db.Where("`key` = ?", entity.ConfigKeySmsSign).First(&signCfg).Error; err == nil {
			smsSign = signCfg.Value
		}
		if smsUser != "" && smsPass != "" {
			if err := sendSmsBao(smsUser, smsPass, smsSign, accountKey, code); err != nil {
				fmt.Printf("[ERROR] Send sms failed: %v\n", err)
				ctx.Fail("sms_send_failed")
				return
			}
		} else {
			fmt.Printf("[DEV] SMS Code for %s: %s\n", accountKey, code)
		}
	} else {
		// Load every smtp setting in one query instead of reusing one struct.
		var smtpConfigs []entity.DtConfig
		if err := db.Where("`group` = ?", "smtp").Find(&smtpConfigs).Error; err != nil {
			fmt.Printf("[ERROR] Load smtp config failed: %v\n", err)
			ctx.Fail("email_send_failed")
			return
		}
		var smtpHost, smtpPortStr, smtpUser, smtpPass, smtpFrom string
		for _, cfg := range smtpConfigs {
			switch cfg.Key {
			case entity.ConfigKeySmtpHost:
				smtpHost = cfg.Value
			case entity.ConfigKeySmtpPort:
				smtpPortStr = cfg.Value
			case entity.ConfigKeySmtpUser:
				smtpUser = cfg.Value
			case entity.ConfigKeySmtpPass:
				smtpPass = cfg.Value
			case entity.ConfigKeySmtpFrom:
				smtpFrom = cfg.Value
			}
		}
		if smtpFrom == "" {
			smtpFrom = "Vitiens"
		}
		if smtpHost != "" && smtpUser != "" && smtpPass != "" {
			// A missing or non-numeric port falls back to 587 (submission).
			smtpPort, _ := strconv.Atoi(smtpPortStr)
			if smtpPort == 0 {
				smtpPort = 587
			}
			if err := sendEmail(smtpHost, smtpPort, smtpUser, smtpPass, smtpFrom, req.Account, code); err != nil {
				fmt.Printf("[ERROR] Send email failed: %v\n", err)
				ctx.Fail("email_send_failed")
				return
			}
		} else {
			fmt.Printf("[DEV] Email Code for %s: %s\n", req.Account, code)
		}
	}

	codeKey := fmt.Sprintf("verifycode:%s:%s", req.Scene, accountKey)
	if err := redisclient.DefaultClient().Set(c, codeKey, code, 5*time.Minute).Err(); err != nil {
		// The code went out but could never be verified, so report the send
		// as failed instead of "code_sent".
		fmt.Printf("[ERROR] Store verify code failed: %v\n", err)
		ctx.Fail(sendFailed)
		return
	}
	ctx.OK(gin.H{
		"message": "code_sent",
	})
}
// Register creates an account from a verified phone number or e-mail address
// and returns a JWT plus a user summary.
//
// The request JSON has these fields:
//   - type: "phone" or "email".
//   - account: the phone number or e-mail address.
//   - code: the "register" scene code issued by SendSmsCode.
//   - password: at least 6 characters.
//   - inviteCode: optional.
//   - areaCode: optional, phone only; it is prepended to the account.
//
// Five wrong codes for an account lock it for 15 minutes. The user row and
// the inviter counter updates run in one transaction. If any of them fails,
// the whole registration is rolled back.
func (s *Server) Register(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()
	type Req struct {
		Type       string `json:"type" binding:"required"`    // phone/email
		Account    string `json:"account" binding:"required"` // phone number or e-mail
		Code       string `json:"code" binding:"required"`
		Password   string `json:"password" binding:"required,min=6"`
		InviteCode string `json:"inviteCode"`
		AreaCode   string `json:"areaCode"` // international dialing code
	}
	var req Req
	if err := c.ShouldBindJSON(&req); err != nil {
		ctx.Fail("invalid_params")
		return
	}
	if req.Type != "phone" && req.Type != "email" {
		ctx.Fail("unsupported_type")
		return
	}

	// accountKey matches the key SendSmsCode stored the code under.
	accountKey := req.Account
	if req.Type == "phone" && req.AreaCode != "" {
		accountKey = req.AreaCode + req.Account
	}

	// Brute-force guard: 5 wrong codes lock the account for 15 minutes.
	failKey := fmt.Sprintf("code_fail:register:%s", accountKey)
	failCount, _ := redisclient.DefaultClient().Get(c, failKey).Int()
	if failCount >= 5 {
		ctx.Fail("too_many_attempts")
		return
	}
	codeKey := fmt.Sprintf("verifycode:register:%s", accountKey)
	storedCode, err := redisclient.DefaultClient().Get(c, codeKey).Result()
	if err != nil || storedCode != req.Code {
		redisclient.DefaultClient().Incr(c, failKey)
		redisclient.DefaultClient().Expire(c, failKey, 15*time.Minute)
		ctx.Fail("invalid_code")
		return
	}
	redisclient.DefaultClient().Del(c, failKey)

	// Reject accounts that are already registered.
	var existUser entity.DtUser
	if req.Type == "phone" {
		if err := db.Where("phone = ?", accountKey).First(&existUser).Error; err == nil {
			ctx.Fail("phone_registered")
			return
		}
	} else {
		if err := db.Where("email = ?", accountKey).First(&existUser).Error; err == nil {
			ctx.Fail("email_registered")
			return
		}
	}

	// An unknown invite code is ignored rather than rejected.
	var parentId int64
	if req.InviteCode != "" {
		var inviter entity.DtUser
		if err := db.Where("invite_code = ?", req.InviteCode).First(&inviter).Error; err == nil {
			parentId = inviter.Id
		}
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		ctx.Fail("register_failed")
		return
	}

	// If no default level exists, defaultLevel stays zero and LevelId is 0.
	var defaultLevel entity.DtUserLevel
	db.Where("is_default = ?", 1).First(&defaultLevel)

	// Nickname is the localized "user" prefix followed by either the e-mail
	// local part or the last four characters of the phone number.
	userPrefix := ctx.I18n("user_prefix")
	if userPrefix == "" || userPrefix == "user_prefix" {
		userPrefix = "User" // translation missing
	}
	accountSuffix := req.Account
	if req.Type == "email" {
		if atIdx := strings.Index(req.Account, "@"); atIdx > 0 {
			accountSuffix = req.Account[:atIdx]
		}
	} else if len(req.Account) >= 4 {
		accountSuffix = req.Account[len(req.Account)-4:]
	}

	user := &entity.DtUser{
		Uid:        generateUid(),
		Password:   string(hashedPassword),
		Nickname:   userPrefix + accountSuffix,
		ParentId:   parentId,
		LevelId:    defaultLevel.Id,
		InviteCode: generateInviteCode(),
		Status:     1,
	}
	if req.Type == "phone" {
		user.Phone = accountKey
	} else {
		user.Email = accountKey
		// E-mail users get a unique placeholder phone so the uk_phone unique
		// index does not collide.
		sum := md5.Sum([]byte(req.Account))
		user.Phone = "email_" + hex.EncodeToString(sum[:])[:16]
	}

	tx := db.Begin()
	if tx.Error != nil {
		ctx.Fail("register_failed")
		return
	}
	abort := func() {
		tx.Rollback()
		ctx.Fail("register_failed")
	}
	if err := tx.Create(user).Error; err != nil {
		abort()
		return
	}
	// Bump the inviter's direct/team counts and the grandparent's team count.
	if parentId > 0 {
		if err := tx.Model(&entity.DtUser{}).Where("id = ?", parentId).
			Updates(map[string]interface{}{
				"direct_invite_count": tx.Raw("direct_invite_count + 1"),
				"team_count":          tx.Raw("team_count + 1"),
			}).Error; err != nil {
			abort()
			return
		}
		var parent entity.DtUser
		if err := tx.Where("id = ?", parentId).First(&parent).Error; err == nil && parent.ParentId > 0 {
			if err := tx.Model(&entity.DtUser{}).Where("id = ?", parent.ParentId).
				Update("team_count", tx.Raw("team_count + 1")).Error; err != nil {
				abort()
				return
			}
		}
	}
	// Unchecked, a failed commit would hand out a token for a user that was
	// never persisted.
	if err := tx.Commit().Error; err != nil {
		ctx.Fail("register_failed")
		return
	}

	redisclient.DefaultClient().Del(c, codeKey)

	token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
	if err != nil {
		ctx.Fail("register_failed")
		return
	}
	ctx.OK(gin.H{
		"token": token,
		"user": gin.H{
			"id":       user.Id,
			"uid":      user.Uid,
			"nickname": user.Nickname,
			"avatar":   user.Avatar,
			"phone":    user.Phone,
			"email":    user.Email,
		},
	})
}
// LoginByPassword authenticates with account and password and returns a JWT
// plus a user summary.
//
// The request JSON has these fields:
//   - type: "email" looks the account up by e-mail; any other value looks it
//     up by phone.
//   - account: the phone number or e-mail address.
//   - password: the account password.
//   - areaCode: optional, phone only; it is prepended to the account.
//
// Five wrong passwords for the same account lock it for 15 minutes.
func (s *Server) LoginByPassword(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()
	type Req struct {
		Type     string `json:"type" binding:"required"`    // phone/email
		Account  string `json:"account" binding:"required"` // phone number or e-mail
		Password string `json:"password" binding:"required"`
		AreaCode string `json:"areaCode"` // international dialing code
	}
	var req Req
	if err := c.ShouldBindJSON(&req); err != nil {
		ctx.Fail("invalid_params")
		return
	}

	// Resolve the stored identifier once: e-mail as given, phone with the
	// area code prepended (the form Register stores in DtUser.Phone).
	accountKey := req.Account
	if req.Type != "email" && req.AreaCode != "" {
		accountKey = req.AreaCode + req.Account
	}

	// Key the lockout on the same identifier used for the lookup. Otherwise
	// one subscriber number under different area codes shares a counter, and
	// failures against one account can lock out another.
	loginFailKey := fmt.Sprintf("login_fail:%s", accountKey)
	loginFailCount, _ := redisclient.DefaultClient().Get(c, loginFailKey).Int()
	if loginFailCount >= 5 {
		ctx.Fail("too_many_attempts")
		return
	}

	var user entity.DtUser
	lookup := "phone = ?"
	if req.Type == "email" {
		lookup = "email = ?"
	}
	if err := db.Where(lookup, accountKey).First(&user).Error; err != nil {
		ctx.Fail("user_not_found")
		return
	}
	if user.Status != 1 {
		ctx.Fail("user_disabled")
		return
	}

	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
		redisclient.DefaultClient().Incr(c, loginFailKey)
		redisclient.DefaultClient().Expire(c, loginFailKey, 15*time.Minute)
		ctx.Fail("invalid_password")
		return
	}
	redisclient.DefaultClient().Del(c, loginFailKey)

	db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
		Update("last_login_time", time.Now().Unix())

	token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
	if err != nil {
		ctx.Fail("login_failed")
		return
	}
	ctx.OK(gin.H{
		"token": token,
		"user": gin.H{
			"id":       user.Id,
			"uid":      user.Uid,
			"nickname": user.Nickname,
			"avatar":   user.Avatar,
			"phone":    user.Phone,
			"email":    user.Email,
		},
	})
}
// LoginBySms signs a user in with a phone number and a "login" scene
// verification code (as issued by SendSmsCode), returning a JWT and a user
// summary.
//
// The full phone number is areaCode + phone. Five wrong codes for the same
// number lock further attempts for 15 minutes. The code is deleted only
// after the user has been found and confirmed active.
func (s *Server) LoginBySms(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()

	type smsLoginReq struct {
		Phone    string `json:"phone" binding:"required"`
		Code     string `json:"code" binding:"required"`
		AreaCode string `json:"areaCode"` // international dialing code
	}
	var req smsLoginReq
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		ctx.Fail("invalid_params")
		return
	}

	phone := req.AreaCode + req.Phone
	rdb := redisclient.DefaultClient()

	attemptsKey := "code_fail:login:" + phone
	if attempts, _ := rdb.Get(c, attemptsKey).Int(); attempts >= 5 {
		ctx.Fail("too_many_attempts")
		return
	}

	codeKey := "verifycode:login:" + phone
	stored, err := rdb.Get(c, codeKey).Result()
	if err != nil || stored != req.Code {
		rdb.Incr(c, attemptsKey)
		rdb.Expire(c, attemptsKey, 15*time.Minute)
		ctx.Fail("invalid_code")
		return
	}
	rdb.Del(c, attemptsKey)

	var user entity.DtUser
	if lookupErr := db.Where("phone = ?", phone).First(&user).Error; lookupErr != nil {
		ctx.Fail("user_not_found")
		return
	}
	if user.Status != 1 {
		ctx.Fail("user_disabled")
		return
	}

	db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
		Update("last_login_time", time.Now().Unix())
	rdb.Del(c, codeKey)

	token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
	if err != nil {
		ctx.Fail("login_failed")
		return
	}
	profile := gin.H{
		"id":       user.Id,
		"uid":      user.Uid,
		"nickname": user.Nickname,
		"avatar":   user.Avatar,
		"phone":    user.Phone,
		"email":    user.Email,
	}
	ctx.OK(gin.H{"token": token, "user": profile})
}
// ResetPassword replaces a user's password after verifying a "reset" scene
// code issued by SendSmsCode.
//
// The request JSON has these fields:
//   - type: "phone" or "email".
//   - account: the phone number or e-mail address.
//   - code: the verification code.
//   - newPassword: at least 6 characters.
//   - areaCode: optional, phone only; it is prepended to the account.
//
// Five wrong codes lock the account for 15 minutes. The code is deleted only
// after the new password has been saved, so a failed save can be retried.
func (s *Server) ResetPassword(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()
	type Req struct {
		Type        string `json:"type" binding:"required"`    // phone/email
		Account     string `json:"account" binding:"required"` // phone number or e-mail
		Code        string `json:"code" binding:"required"`
		NewPassword string `json:"newPassword" binding:"required,min=6"`
		AreaCode    string `json:"areaCode"` // international dialing code
	}
	var req Req
	if err := c.ShouldBindJSON(&req); err != nil {
		ctx.Fail("invalid_params")
		return
	}
	// Same type check as Register and SendSmsCode, so an unknown type is not
	// silently treated as a phone number.
	if req.Type != "phone" && req.Type != "email" {
		ctx.Fail("unsupported_type")
		return
	}

	accountKey := req.Account
	if req.Type == "phone" && req.AreaCode != "" {
		accountKey = req.AreaCode + req.Account
	}

	// Brute-force guard: 5 wrong codes lock the account for 15 minutes.
	failKey := fmt.Sprintf("code_fail:reset:%s", accountKey)
	failCount, _ := redisclient.DefaultClient().Get(c, failKey).Int()
	if failCount >= 5 {
		ctx.Fail("too_many_attempts")
		return
	}
	codeKey := fmt.Sprintf("verifycode:reset:%s", accountKey)
	storedCode, err := redisclient.DefaultClient().Get(c, codeKey).Result()
	if err != nil || storedCode != req.Code {
		redisclient.DefaultClient().Incr(c, failKey)
		redisclient.DefaultClient().Expire(c, failKey, 15*time.Minute)
		ctx.Fail("invalid_code")
		return
	}
	redisclient.DefaultClient().Del(c, failKey)

	var user entity.DtUser
	lookup := "phone = ?"
	if req.Type == "email" {
		lookup = "email = ?"
	}
	if err := db.Where(lookup, accountKey).First(&user).Error; err != nil {
		ctx.Fail("user_not_found")
		return
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
	if err != nil {
		ctx.Fail("reset_failed")
		return
	}
	// Check the write. Reporting success for a password that was never saved
	// would leave the user unable to log in with either password.
	if err := db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
		Update("password", string(hashedPassword)).Error; err != nil {
		ctx.Fail("reset_failed")
		return
	}

	redisclient.DefaultClient().Del(c, codeKey)
	ctx.OK(gin.H{
		"message": "password_reset_success",
	})
}
// OAuthLogin signs a user in (or signs them up) through a third-party provider.
//
// The request JSON has these fields:
//   - provider: e.g. google/zalo/telegram/tiktok.
//   - openId: the provider user id.
//   - nickname, avatar, inviteCode: optional.
//   - extra: provider-specific data.
//
// For "tiktok" and "zalo", openId carries an OAuth authorization code. It is
// exchanged server-side, with extra holding the TikTok redirect_uri or the
// Zalo PKCE code_verifier, and the verified profile replaces the
// client-supplied id, nickname and avatar.
//
// An existing (provider, account) binding logs that user in. Otherwise a new
// user (with a placeholder phone) and the binding are created in one
// transaction, and the response includes "isNew": true.
func (s *Server) OAuthLogin(c *gin.Context) {
	ctx := s.FromContext(c)
	db := s.DB()
	type Req struct {
		Provider   string `json:"provider" binding:"required"` // google/zalo/telegram/tiktok
		OpenId     string `json:"openId" binding:"required"`   // provider user id, or an auth code for tiktok/zalo
		Nickname   string `json:"nickname"`
		Avatar     string `json:"avatar"`
		InviteCode string `json:"inviteCode"`
		Extra      string `json:"extra"` // tiktok: redirect_uri; zalo: PKCE code_verifier
	}
	var req Req
	if err := c.ShouldBindJSON(&req); err != nil {
		ctx.Fail("invalid_params")
		return
	}

	// Without extra the code exchange cannot happen. Falling back to the raw
	// client-supplied openId would let anyone log in as any TikTok/Zalo-bound
	// account, so such requests are rejected.
	switch req.Provider {
	case "tiktok":
		if req.Extra == "" {
			ctx.Fail("tiktok_auth_failed")
			return
		}
		tiktokUser, err := s.getTiktokUserInfo(req.OpenId, req.Extra)
		if err != nil {
			fmt.Printf("TikTok get user info error: %v\n", err)
			ctx.Fail("tiktok_auth_failed")
			return
		}
		req.OpenId = tiktokUser.ID
		req.Nickname = tiktokUser.Name
		req.Avatar = tiktokUser.Picture
	case "zalo":
		if req.Extra == "" {
			ctx.Fail("zalo_auth_failed")
			return
		}
		zaloUser, err := s.getZaloUserInfo(req.OpenId, req.Extra)
		if err != nil {
			fmt.Printf("Zalo get user info error: %v\n", err)
			ctx.Fail("zalo_auth_failed")
			return
		}
		req.OpenId = zaloUser.ID
		req.Nickname = zaloUser.Name
		req.Avatar = zaloUser.Picture
	default:
		// NOTE(review): for other providers (e.g. google/telegram) openId is
		// trusted as sent by the client, with no server-side verification
		// here. Anyone who knows a user's provider id can sign in as them.
		// Verify the provider's ID token or signed login payload before using
		// openId.
	}

	// Already bound: log the existing user in.
	var social entity.DtUserSocial
	if err := db.Where("platform = ? AND account = ?", req.Provider, req.OpenId).First(&social).Error; err == nil {
		var user entity.DtUser
		if err := db.Where("id = ?", social.UserId).First(&user).Error; err != nil {
			ctx.Fail("user_not_found")
			return
		}
		if user.Status != 1 {
			ctx.Fail("user_disabled")
			return
		}
		db.Model(&entity.DtUser{}).Where("id = ?", user.Id).
			Update("last_login_time", time.Now().Unix())
		token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
		if err != nil {
			ctx.Fail("login_failed")
			return
		}
		ctx.OK(gin.H{
			"token": token,
			"user": gin.H{
				"id":       user.Id,
				"uid":      user.Uid,
				"nickname": user.Nickname,
				"avatar":   user.Avatar,
				"phone":    user.Phone,
				"email":    user.Email, // same user shape as every other login response
			},
		})
		return
	}

	// Not bound yet: create a new user. An unknown invite code is ignored.
	var parentId int64
	if req.InviteCode != "" {
		var inviter entity.DtUser
		if err := db.Where("invite_code = ?", req.InviteCode).First(&inviter).Error; err == nil {
			parentId = inviter.Id
		}
	}

	// If no default level exists, defaultLevel stays zero and LevelId is 0.
	var defaultLevel entity.DtUserLevel
	db.Where("is_default = ?", 1).First(&defaultLevel)

	nickname := req.Nickname
	if nickname == "" {
		oauthPrefix := ctx.I18n("user_prefix")
		if oauthPrefix == "" || oauthPrefix == "user_prefix" {
			oauthPrefix = "User" // translation missing
		}
		nickname = req.Provider + oauthPrefix
	}

	// Unique placeholder phone for the uk_phone index. The openId is hashed
	// so that long ids are not truncated into collisions.
	openIdSum := md5.Sum([]byte(req.OpenId))
	oauthPhone := fmt.Sprintf("oauth_%s_%s", req.Provider, hex.EncodeToString(openIdSum[:])[:16])

	user := &entity.DtUser{
		Uid:        generateUid(),
		Phone:      oauthPhone,
		Nickname:   nickname,
		Avatar:     req.Avatar,
		ParentId:   parentId,
		LevelId:    defaultLevel.Id,
		InviteCode: generateInviteCode(),
		Status:     1,
	}

	tx := db.Begin()
	if tx.Error != nil {
		fmt.Printf("OAuth begin tx error: %v\n", tx.Error)
		ctx.Fail("login_failed")
		return
	}
	abort := func(step string, err error) {
		tx.Rollback()
		fmt.Printf("OAuth %s error: %v\n", step, err)
		ctx.Fail("login_failed")
	}
	if err := tx.Create(user).Error; err != nil {
		abort("create user", err)
		return
	}
	binding := entity.DtUserSocial{
		UserId:   user.Id,
		Platform: req.Provider,
		Account:  req.OpenId,
		Nickname: req.Nickname,
		Avatar:   req.Avatar,
		Extra:    "{}",
	}
	if err := tx.Create(&binding).Error; err != nil {
		abort("create social", err)
		return
	}
	// Bump the inviter's direct/team counts and, as Register does, the
	// grandparent's team count.
	if parentId > 0 {
		if err := tx.Model(&entity.DtUser{}).Where("id = ?", parentId).
			Updates(map[string]interface{}{
				"direct_invite_count": tx.Raw("direct_invite_count + 1"),
				"team_count":          tx.Raw("team_count + 1"),
			}).Error; err != nil {
			abort("update inviter", err)
			return
		}
		var parent entity.DtUser
		if err := tx.Where("id = ?", parentId).First(&parent).Error; err == nil && parent.ParentId > 0 {
			if err := tx.Model(&entity.DtUser{}).Where("id = ?", parent.ParentId).
				Update("team_count", tx.Raw("team_count + 1")).Error; err != nil {
				abort("update team count", err)
				return
			}
		}
	}
	// Unchecked, a failed commit would hand out a token for a user that was
	// never persisted.
	if err := tx.Commit().Error; err != nil {
		fmt.Printf("OAuth commit error: %v\n", err)
		ctx.Fail("login_failed")
		return
	}

	token, err := middleware.GenerateJWT(middleware.Member{ID: user.Id, Uid: user.Uid})
	if err != nil {
		fmt.Printf("OAuth generate JWT error: %v\n", err)
		ctx.Fail("login_failed")
		return
	}
	ctx.OK(gin.H{
		"token": token,
		"user": gin.H{
			"id":       user.Id,
			"uid":      user.Uid,
			"nickname": user.Nickname,
			"avatar":   user.Avatar,
			"phone":    user.Phone,
			"email":    user.Email,
		},
		"isNew": true,
	})
}
// ZaloUser is the Zalo profile data returned by getZaloUserInfo.
// getZaloUserInfo decodes the Zalo Graph API response into its own struct
// (where picture is a nested object) and copies the fields here. These JSON
// tags are therefore not used to decode the API response.
type ZaloUser struct {
ID string `json:"id"` // Zalo user id
Name string `json:"name"` // display name
Picture string `json:"picture"` // avatar URL (picture.data.url in the API response)
}
// getZaloUserInfo exchanges a Zalo OAuth authorization code (PKCE flow, with
// codeVerifier) for an access token and fetches the user's id, name and
// avatar URL.
//
// The app id and secret are read from the config table. Raw provider
// responses are logged only on failure: a successful token response carries
// the access token, and a successful profile response carries personal data.
// Neither should reach the logs.
func (s *Server) getZaloUserInfo(code, codeVerifier string) (*ZaloUser, error) {
	db := s.DB()

	var appIdConfig, secretConfig entity.DtConfig
	if err := db.Where("`key` = ?", entity.ConfigKeyZaloAppId).First(&appIdConfig).Error; err != nil {
		return nil, fmt.Errorf("zalo app_id not configured")
	}
	if err := db.Where("`key` = ?", entity.ConfigKeyZaloSecret).First(&secretConfig).Error; err != nil {
		return nil, fmt.Errorf("zalo secret not configured")
	}
	client := &http.Client{Timeout: 10 * time.Second}

	// 1. Exchange the authorization code for an access token.
	tokenUrl := "https://oauth.zaloapp.com/v4/access_token"
	data := url.Values{}
	data.Set("app_id", appIdConfig.Value)
	data.Set("code", code)
	data.Set("code_verifier", codeVerifier)
	data.Set("grant_type", "authorization_code")
	req, err := http.NewRequest(http.MethodPost, tokenUrl, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("build token request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("secret_key", secretConfig.Value)
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request token failed: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read token response failed: %w", err)
	}
	// expires_in is deliberately not decoded. It is unused, and leaving it out
	// means a change in its JSON type (string vs number) cannot break logins.
	var tokenResp struct {
		AccessToken string `json:"access_token"`
		Error       int    `json:"error"`
		Message     string `json:"message"`
	}
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("parse token response failed: %w", err)
	}
	if tokenResp.Error != 0 || tokenResp.AccessToken == "" {
		fmt.Printf("Zalo token response: %s\n", string(body))
		return nil, fmt.Errorf("get token failed: %s", tokenResp.Message)
	}

	// 2. Fetch the profile with the access token.
	userUrl := "https://graph.zalo.me/v2.0/me?fields=id,name,picture"
	userReq, err := http.NewRequest(http.MethodGet, userUrl, nil)
	if err != nil {
		return nil, fmt.Errorf("build user info request failed: %w", err)
	}
	userReq.Header.Set("access_token", tokenResp.AccessToken)
	userResp, err := client.Do(userReq)
	if err != nil {
		return nil, fmt.Errorf("request user info failed: %w", err)
	}
	defer userResp.Body.Close()
	userBody, err := io.ReadAll(userResp.Body)
	if err != nil {
		return nil, fmt.Errorf("read user info failed: %w", err)
	}

	// Error responses look like {"error": -501, "message": "..."}.
	// Success responses look like
	// {"id": "...", "name": "...", "picture": {"data": {"url": "..."}}}.
	var errorResp struct {
		Error   int    `json:"error"`
		Message string `json:"message"`
	}
	if err := json.Unmarshal(userBody, &errorResp); err == nil && errorResp.Error != 0 {
		fmt.Printf("Zalo user response: %s\n", string(userBody))
		return nil, fmt.Errorf("get user info failed: [%d] %s", errorResp.Error, errorResp.Message)
	}
	var userInfo struct {
		ID      string `json:"id"`
		Name    string `json:"name"`
		Picture struct {
			Data struct {
				URL string `json:"url"`
			} `json:"data"`
		} `json:"picture"`
	}
	if err := json.Unmarshal(userBody, &userInfo); err != nil {
		return nil, fmt.Errorf("parse user info failed: %w", err)
	}
	if userInfo.ID == "" {
		fmt.Printf("Zalo user response: %s\n", string(userBody))
		return nil, fmt.Errorf("get user info failed: empty user id")
	}
	return &ZaloUser{
		ID:      userInfo.ID,
		Name:    userInfo.Name,
		Picture: userInfo.Picture.Data.URL,
	}, nil
}
// TiktokUser is the TikTok profile data returned by getTiktokUserInfo.
// getTiktokUserInfo decodes the TikTok API response into its own struct and
// copies the fields here. These JSON tags therefore do not follow the TikTok
// API field names.
type TiktokUser struct {
ID string `json:"id"` // TikTok open_id
Name string `json:"name"` // display_name
Picture string `json:"picture"` // avatar_url
}
// getTiktokUserInfo exchanges a TikTok OAuth authorization code for an access
// token and fetches the user's open_id, display name and avatar URL.
//
// redirectUri must be the redirect_uri used in the authorization request.
// The client key and secret are read from the config table. Raw provider
// responses are logged only on failure: a successful token response carries
// the access token, and a successful profile response carries personal data.
func (s *Server) getTiktokUserInfo(code, redirectUri string) (*TiktokUser, error) {
	db := s.DB()

	var clientKeyConfig, clientSecretConfig entity.DtConfig
	if err := db.Where("`key` = ?", entity.ConfigKeyTiktokClientKey).First(&clientKeyConfig).Error; err != nil {
		return nil, fmt.Errorf("tiktok client_key not configured")
	}
	if err := db.Where("`key` = ?", entity.ConfigKeyTiktokClientSecret).First(&clientSecretConfig).Error; err != nil {
		return nil, fmt.Errorf("tiktok client_secret not configured")
	}
	client := &http.Client{Timeout: 10 * time.Second}

	// 1. Exchange the authorization code for an access token.
	tokenUrl := "https://open.tiktokapis.com/v2/oauth/token/"
	data := url.Values{}
	data.Set("client_key", clientKeyConfig.Value)
	data.Set("client_secret", clientSecretConfig.Value)
	data.Set("code", code)
	data.Set("grant_type", "authorization_code")
	data.Set("redirect_uri", redirectUri)
	req, err := http.NewRequest(http.MethodPost, tokenUrl, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("build token request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request token failed: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read token response failed: %w", err)
	}
	// Only the fields used below are decoded.
	var tokenResp struct {
		AccessToken string `json:"access_token"`
		OpenId      string `json:"open_id"`
		Error       string `json:"error"`
		ErrorDesc   string `json:"error_description"`
	}
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("parse token response failed: %w", err)
	}
	if tokenResp.Error != "" || tokenResp.AccessToken == "" {
		fmt.Printf("TikTok token response: %s\n", string(body))
		return nil, fmt.Errorf("get token failed: %s - %s", tokenResp.Error, tokenResp.ErrorDesc)
	}

	// 2. Fetch the profile with the access token.
	userUrl := "https://open.tiktokapis.com/v2/user/info/?fields=open_id,display_name,avatar_url"
	userReq, err := http.NewRequest(http.MethodGet, userUrl, nil)
	if err != nil {
		return nil, fmt.Errorf("build user info request failed: %w", err)
	}
	userReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
	userResp, err := client.Do(userReq)
	if err != nil {
		return nil, fmt.Errorf("request user info failed: %w", err)
	}
	defer userResp.Body.Close()
	userBody, err := io.ReadAll(userResp.Body)
	if err != nil {
		return nil, fmt.Errorf("read user info failed: %w", err)
	}

	var userInfoResp struct {
		Data struct {
			User struct {
				OpenId      string `json:"open_id"`
				DisplayName string `json:"display_name"`
				AvatarUrl   string `json:"avatar_url"`
			} `json:"user"`
		} `json:"data"`
		Error struct {
			Code    string `json:"code"`
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(userBody, &userInfoResp); err != nil {
		return nil, fmt.Errorf("parse user info failed: %w", err)
	}
	// Successful responses report error.code "ok" (or omit it).
	if userInfoResp.Error.Code != "" && userInfoResp.Error.Code != "ok" {
		fmt.Printf("TikTok user response: %s\n", string(userBody))
		return nil, fmt.Errorf("get user info failed: [%s] %s", userInfoResp.Error.Code, userInfoResp.Error.Message)
	}

	// Prefer the profile's open_id; fall back to the one in the token response.
	openId := userInfoResp.Data.User.OpenId
	if openId == "" {
		openId = tokenResp.OpenId
	}
	if openId == "" {
		fmt.Printf("TikTok user response: %s\n", string(userBody))
		return nil, fmt.Errorf("get user info failed: empty open_id")
	}
	return &TiktokUser{
		ID:      openId,
		Name:    userInfoResp.Data.User.DisplayName,
		Picture: userInfoResp.Data.User.AvatarUrl,
	}, nil
}