188 lines
4.5 KiB
Go
188 lines
4.5 KiB
Go
package mysql
|
||
|
||
import (
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"acquaintances/biz/model"
	"acquaintances/biz/model/user"

	"github.com/cloudwego/hertz/pkg/common/hlog"
	"golang.org/x/crypto/scrypt"
	"gorm.io/gorm"
)
|
||
|
||
func CreateUser(users []*model.User) error {
|
||
errTransaction := DB.Transaction(func(tx *gorm.DB) error {
|
||
//创建用户信息表记录
|
||
err := tx.Create(users).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
//创建用户关系表记录
|
||
var rel model.UserRelations
|
||
rel.UserID = users[0].UserID
|
||
err = tx.Model(&model.UserRelations{}).Create(&rel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
return errTransaction
|
||
}
|
||
|
||
func DeleteUser(userId string) error {
|
||
errTransaction := DB.Transaction(func(tx *gorm.DB) error {
|
||
err := tx.Model(&model.User{}).Where("user_id = ?", userId).Delete(&model.User{}).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
return errTransaction
|
||
}
|
||
|
||
func UpdatesUser(user *user.UpdateUserRequest) error {
|
||
db := DB.Model(&model.User{})
|
||
maps := make(map[string]interface{})
|
||
if user.Name != "" {
|
||
maps["user_name"] = user.Name
|
||
}
|
||
if user.Birthday != "" {
|
||
maps["birthday"] = user.Birthday
|
||
}
|
||
if user.Address != "" {
|
||
maps["address"] = user.Address
|
||
}
|
||
if user.Area != 0 {
|
||
maps["area"] = user.Area
|
||
}
|
||
if user.Mobile != "" {
|
||
maps["mobile"] = user.Mobile
|
||
}
|
||
if user.Age != 0 {
|
||
maps["age"] = user.Age
|
||
}
|
||
if user.AvatarImageURL != "" {
|
||
maps["avatar_image_url"] = user.AvatarImageURL
|
||
}
|
||
if user.Introduce != "" {
|
||
maps["introduce"] = user.Introduce
|
||
}
|
||
if user.Gender != 0 {
|
||
maps["gender"] = user.Gender
|
||
}
|
||
hlog.Info(maps)
|
||
return db.Where("user_id = ?", user.UserID).Updates(maps).Error
|
||
}
|
||
|
||
// InfoUser 查询单个用户
|
||
// 参数:用户ID
|
||
// 返回:用户信息,错误
|
||
func InfoUser(id string) (*model.User, error) {
|
||
res := new(model.User)
|
||
db := DB.Model(&model.User{}).Where("user_id = ?", id).First(res)
|
||
if err := db.Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return res, nil
|
||
}
|
||
|
||
// GetUsersById 批量查询用户信息
|
||
// 参数:用户ID切片
|
||
// 返回:用户信息切片指针,错误
|
||
func GetUsersById(ids ...string) ([]*user.UserInfoReq, error) {
|
||
db := DB.Model(&model.User{})
|
||
// 参数校验
|
||
if len(ids) == 0 {
|
||
return nil, errors.New("user IDs cannot be empty")
|
||
}
|
||
query := db.
|
||
Where("user_id IN (?)", ids)
|
||
|
||
var users []*user.UserInfoReq
|
||
if err := query.Find(&users).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return users, nil
|
||
}
|
||
|
||
func FindUser(keyword *string) (*user.UserInfoReq, error) {
|
||
db := DB.Model(model.User{})
|
||
if keyword != nil && len(*keyword) != 0 {
|
||
db = db.Where(DB.Or("user_id = ?", keyword).
|
||
Or("user_name = ?", keyword).
|
||
Or("mobile = ?", keyword))
|
||
}
|
||
var res *user.UserInfoReq
|
||
if err := db.First(&res).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return res, nil
|
||
}
|
||
|
||
// CheckLogin 验证登录(适配 Scrypt 加密)
|
||
func CheckLogin(username, password string) (*model.User, error) {
|
||
var user model.User
|
||
|
||
// 1. 查询用户(支持 user_id/手机号/用户名登录)
|
||
result := DB.Model(&model.User{}).Where(
|
||
"user_id = ? OR mobile = ? OR user_name = ?",
|
||
username, username, username,
|
||
).First(&user)
|
||
|
||
// 2. 处理查询错误
|
||
if result.Error != nil {
|
||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||
return nil, fmt.Errorf("用户名或密码错误") // 模糊错误,防枚举
|
||
}
|
||
return nil, fmt.Errorf("登录失败:%w", result.Error)
|
||
}
|
||
|
||
// 3. 检查用户状态
|
||
//TODO
|
||
|
||
// 4. 验证密码(核心:使用 Scrypt 算法校验)
|
||
// 4.1 解码存储的盐值(加密时用 base64.URLEncoding)
|
||
salt, err := base64.URLEncoding.DecodeString(user.UserSalt)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("密码验证失败")
|
||
}
|
||
|
||
// 4.2 用相同算法计算输入密码的哈希
|
||
const keyLen = 10
|
||
inputHash, err := scrypt.Key([]byte(password), salt, 16384, 8, 1, keyLen)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("密码验证失败")
|
||
}
|
||
|
||
// 4.3 对比计算结果与存储的密码哈希
|
||
storedHash := user.UserPassword
|
||
inputHashBase64 := base64.StdEncoding.EncodeToString(inputHash)
|
||
if inputHashBase64 != storedHash {
|
||
return nil, fmt.Errorf("用户名或密码错误")
|
||
}
|
||
|
||
// 5. 更新最后登录信息
|
||
return &user, updateLastLoginInfo(user.ID, getClientIP())
|
||
}
|
||
|
||
// 辅助函数:更新登录信息
|
||
func updateLastLoginInfo(userID uint, loginIP string) error {
|
||
now := time.Now()
|
||
return DB.Model(&model.User{}).Where("id = ?", userID).Updates(map[string]interface{}{
|
||
"last_login_ip": loginIP,
|
||
"last_login_time": &now,
|
||
}).Error
|
||
}
|
||
|
||
// getClientIP returns the IP address recorded as the login source.
//
// TODO: placeholder that always returns the loopback address. Presumably the
// real client IP should come from the incoming request; confirm with caller.
func getClientIP() string {
	const placeholderIP = "127.0.0.1"
	return placeholderIP
}
|