Files
IUQT/acquaintances/biz/dal/mysql/user.go

188 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"acquaintances/biz/model"
	"acquaintances/biz/model/user"

	"github.com/cloudwego/hertz/pkg/common/hlog"
	"golang.org/x/crypto/scrypt"
	"gorm.io/gorm"
)
// CreateUser inserts the given users and one user-relations row per user
// inside a single transaction.
//
// It returns an error when users is empty, otherwise the first error reported
// by the database. Returning an error from the transaction callback makes
// GORM roll the whole transaction back, so no partial data is left behind.
func CreateUser(users []*model.User) error {
	// Reject empty input explicitly instead of relying on the ORM to do so.
	if len(users) == 0 {
		return errors.New("users cannot be empty")
	}
	return DB.Transaction(func(tx *gorm.DB) error {
		// Create the user records.
		if err := tx.Create(users).Error; err != nil {
			return err
		}
		// Create a relations record for every inserted user, not only the
		// first element of the slice.
		rels := make([]model.UserRelations, len(users))
		for i, u := range users {
			rels[i].UserID = u.UserID
		}
		return tx.Model(&model.UserRelations{}).Create(&rels).Error
	})
}
// DeleteUser issues a GORM Delete on the model.User rows whose user_id
// equals userId. The statement runs inside a transaction, and any database
// error is returned to the caller.
func DeleteUser(userId string) error {
	return DB.Transaction(func(tx *gorm.DB) error {
		target := tx.Model(&model.User{}).Where("user_id = ?", userId)
		return target.Delete(&model.User{}).Error
	})
}
// UpdatesUser applies a partial update to the user row identified by
// user.UserID. Only request fields that hold a non-zero value (non-empty
// string, non-zero number) are written; all other columns are left alone.
// The collected column/value map is logged before the update is executed.
func UpdatesUser(user *user.UpdateUserRequest) error {
	query := DB.Model(&model.User{})
	changes := map[string]interface{}{}

	// put records a column only when the request actually carried a value.
	put := func(column string, value interface{}, present bool) {
		if present {
			changes[column] = value
		}
	}

	put("user_name", user.Name, user.Name != "")
	put("birthday", user.Birthday, user.Birthday != "")
	put("address", user.Address, user.Address != "")
	put("area", user.Area, user.Area != 0)
	put("mobile", user.Mobile, user.Mobile != "")
	put("age", user.Age, user.Age != 0)
	put("avatar_image_url", user.AvatarImageURL, user.AvatarImageURL != "")
	put("introduce", user.Introduce, user.Introduce != "")
	put("gender", user.Gender, user.Gender != 0)

	hlog.Info(changes)
	return query.Where("user_id = ?", user.UserID).Updates(changes).Error
}
// InfoUser looks up a single user by user ID.
//
// It returns the matching record, or a nil user together with the error
// reported by the query (including the error First yields when no row
// matches).
func InfoUser(id string) (*model.User, error) {
	var record model.User
	if err := DB.Model(&model.User{}).Where("user_id = ?", id).First(&record).Error; err != nil {
		return nil, err
	}
	return &record, nil
}
// GetUsersById fetches user information for a batch of user IDs.
//
// It returns one entry per matching row, or an error when no IDs are given
// or the query fails.
func GetUsersById(ids ...string) ([]*user.UserInfoReq, error) {
	base := DB.Model(&model.User{})

	// Validate the arguments before running the query.
	if len(ids) == 0 {
		return nil, errors.New("user IDs cannot be empty")
	}

	var found []*user.UserInfoReq
	err := base.Where("user_id IN (?)", ids).Find(&found).Error
	if err != nil {
		return nil, err
	}
	return found, nil
}
// FindUser returns the first user whose user_id, user_name or mobile equals
// the given keyword exactly.
//
// A nil or empty keyword matches nothing and yields gorm.ErrRecordNotFound,
// rather than running an unfiltered query that would return an arbitrary
// row. Any error from the query itself is returned unchanged.
func FindUser(keyword *string) (*user.UserInfoReq, error) {
	if keyword == nil || *keyword == "" {
		return nil, gorm.ErrRecordNotFound
	}
	// Bind the string value itself rather than the pointer.
	kw := *keyword

	found := new(user.UserInfoReq)
	err := DB.Model(&model.User{}).
		Where("user_id = ? OR user_name = ? OR mobile = ?", kw, kw, kw).
		First(found).Error
	if err != nil {
		return nil, err
	}
	return found, nil
}
// CheckLogin verifies login credentials against the stored scrypt hash.
//
// username may be a user_id, a mobile number or a user name. On success the
// user's last-login information is updated and the user record is returned.
// An unknown user and a wrong password both produce the same generic error
// message so that the response text cannot be used to enumerate accounts.
// On any failure, including a failed last-login update, the returned user
// is nil.
func CheckLogin(username, password string) (*model.User, error) {
	var account model.User

	// 1. Look up the user (login by user_id / mobile / user name).
	err := DB.Model(&model.User{}).Where(
		"user_id = ? OR mobile = ? OR user_name = ?",
		username, username, username,
	).First(&account).Error

	// 2. Handle lookup errors.
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("用户名或密码错误") // deliberately vague to prevent enumeration
		}
		return nil, fmt.Errorf("登录失败:%w", err)
	}

	// 3. TODO: check the user's status.

	// 4. Verify the password with scrypt.
	// 4.1 Decode the stored salt (it is decoded here with base64.URLEncoding).
	salt, err := base64.URLEncoding.DecodeString(account.UserSalt)
	if err != nil {
		return nil, fmt.Errorf("密码验证失败")
	}

	// 4.2 Derive the hash of the supplied password. These parameters must
	// match the ones used when the stored hash was created.
	const keyLen = 10
	derived, err := scrypt.Key([]byte(password), salt, 16384, 8, 1, keyLen)
	if err != nil {
		return nil, fmt.Errorf("密码验证失败")
	}

	// 4.3 Compare against the stored hash in constant time so the comparison
	// does not leak how many leading bytes matched.
	encoded := base64.StdEncoding.EncodeToString(derived)
	if subtle.ConstantTimeCompare([]byte(encoded), []byte(account.UserPassword)) != 1 {
		return nil, fmt.Errorf("用户名或密码错误")
	}

	// 5. Record the last-login information; do not hand back a user together
	// with a non-nil error.
	if err := updateLastLoginInfo(account.ID, getClientIP()); err != nil {
		return nil, err
	}
	return &account, nil
}
// updateLastLoginInfo writes loginIP and the current time into the
// last_login_ip and last_login_time columns of the user row whose id
// equals userID, and returns any error from the update.
func updateLastLoginInfo(userID uint, loginIP string) error {
	loginAt := time.Now()
	fields := map[string]interface{}{
		"last_login_ip":   loginIP,
		"last_login_time": &loginAt,
	}
	return DB.Model(&model.User{}).Where("id = ?", userID).Updates(fields).Error
}
// getClientIP returns the client IP address recorded at login.
// TODO: resolve the real client address; for now this is a fixed
// loopback placeholder.
func getClientIP() string {
	const placeholderIP = "127.0.0.1"
	return placeholderIP
}