Files
IUQT/acquaintances/biz/dal/mysql/user.go

213 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"github.com/cloudwego/hertz/pkg/common/hlog"
	"golang.org/x/crypto/scrypt"
	"gorm.io/gorm"

	"acquaintances/biz/model"
	"acquaintances/biz/model/user"
)
// CreateUser inserts the given users and, within the same transaction, one
// user-relations row per inserted user. If any statement fails the whole
// transaction is rolled back and the error is returned.
//
// users must be non-empty and must not contain nil entries; otherwise an
// error is returned before the database is touched.
func CreateUser(users []*model.User) error {
	if len(users) == 0 {
		return errors.New("users cannot be empty")
	}
	for i, u := range users {
		if u == nil {
			return fmt.Errorf("users[%d] is nil", i)
		}
	}

	return DB.Transaction(func(tx *gorm.DB) error {
		// Insert the user records.
		if err := tx.Create(users).Error; err != nil {
			return err
		}

		// Build the relation rows only after the insert, in case the insert
		// (e.g. a model hook) populates UserID. Every user gets a row, not
		// just the first one.
		rels := make([]model.UserRelations, 0, len(users))
		for _, u := range users {
			rels = append(rels, model.UserRelations{UserID: u.UserID})
		}
		return tx.Create(&rels).Error
	})
}
// DeleteUser issues a Delete on the user model for rows whose user_id column
// equals userId. The statement runs inside a transaction and the
// transaction's error (nil on success) is returned.
func DeleteUser(userId string) error {
	return DB.Transaction(func(tx *gorm.DB) error {
		target := tx.Model(&model.User{}).Where("user_id = ?", userId)
		return target.Delete(&model.User{}).Error
	})
}
// UpdatesUser updates the row whose user_id equals user.UserID with the
// fields of the request that are set. A field left at its zero value ("" or
// 0) is skipped, so this function cannot reset a column to its zero value.
// The collected column/value map is logged before the update is executed.
func UpdatesUser(user *user.UpdateUserRequest) error {
	target := DB.Model(&model.User{})
	changes := make(map[string]interface{})

	// record stores a column in the update set only when the request
	// actually carries a value for it.
	record := func(column string, present bool, value interface{}) {
		if present {
			changes[column] = value
		}
	}

	record("user_name", user.Name != "", user.Name)
	record("birthday", user.Birthday != "", user.Birthday)
	record("address", user.Address != "", user.Address)
	record("area", user.Area != 0, user.Area)
	record("mobile", user.Mobile != "", user.Mobile)
	record("age", user.Age != 0, user.Age)
	record("avatar_image_url", user.AvatarImageURL != "", user.AvatarImageURL)
	record("introduce", user.Introduce != "", user.Introduce)
	record("gender", user.Gender != 0, user.Gender)

	hlog.Info(changes)
	return target.Where("user_id = ?", user.UserID).Updates(changes).Error
}
// InfoUser looks up a single user.
//
// id is matched against the user_id column. The first matching row is
// returned; if the query reports an error, that error is returned together
// with a nil user.
func InfoUser(id string) (*model.User, error) {
	var found model.User
	if err := DB.Model(&model.User{}).Where("user_id = ?", id).First(&found).Error; err != nil {
		return nil, err
	}
	return &found, nil
}
// GetUsersById fetches user information for several users at once.
//
// ids are the user IDs to match against the user_id column; at least one
// must be given, otherwise an error is returned. On success the matching
// rows are returned as a slice of pointers.
func GetUsersById(ids ...string) ([]*user.UserInfoReq, error) {
	base := DB.Model(&model.User{})

	// Validate arguments.
	if len(ids) == 0 {
		return nil, errors.New("user IDs cannot be empty")
	}

	var found []*user.UserInfoReq
	if res := base.Where("user_id IN (?)", ids).Find(&found); res.Error != nil {
		return nil, res.Error
	}
	return found, nil
}
// GetUsersByIdMap fetches user information for several users at once and
// returns it keyed by user ID.
//
// ids are the user IDs to match against the user_id column; at least one
// must be given, otherwise an error is returned. A query failure is returned
// wrapped, so callers can still inspect the underlying error with
// errors.Is / errors.As.
func GetUsersByIdMap(ids ...string) (map[string]*user.UserInfoReq, error) {
	// Validate arguments.
	if len(ids) == 0 {
		return nil, errors.New("user IDs cannot be empty")
	}

	// Run the query.
	var users []*user.UserInfoReq
	if err := DB.Model(&model.User{}).Where("user_id IN (?)", ids).Find(&users).Error; err != nil {
		// %w (not %v) keeps the database error in the chain.
		return nil, fmt.Errorf("查询用户信息失败: %w", err)
	}

	// Index the rows by their UserID field.
	byID := make(map[string]*user.UserInfoReq, len(users))
	for _, u := range users {
		byID[u.UserID] = u
	}
	return byID, nil
}
// FindUser returns the first user whose user_id, user_name or mobile column
// equals *keyword exactly.
//
// NOTE(review): when keyword is nil or empty no filter is applied, so the
// first row of the table is returned. This matches the previous behaviour,
// but confirm that callers really want it.
//
// If the query reports an error, that error is returned with a nil result.
func FindUser(keyword *string) (*user.UserInfoReq, error) {
	query := DB.Model(&model.User{})
	if keyword != nil && *keyword != "" {
		// Bind the string value rather than the pointer, and express the
		// alternatives as one condition, the same way CheckLogin does.
		kw := *keyword
		query = query.Where("user_id = ? OR user_name = ? OR mobile = ?", kw, kw, kw)
	}

	res := new(user.UserInfoReq)
	if err := query.First(res).Error; err != nil {
		return nil, err
	}
	return res, nil
}
// CheckLogin verifies a login attempt against the scrypt-derived password
// hash stored for the user.
//
// username may be the user's user_id, mobile number or user name. On success
// the user record is returned together with the result of updating the
// last-login information. An unknown user and a wrong password produce the
// same error message so that callers cannot tell the two cases apart.
func CheckLogin(username, password string) (*model.User, error) {
	// Named "account" so the local does not shadow the imported user package.
	var account model.User

	// 1. Look the user up by user_id, mobile or user name.
	result := DB.Model(&model.User{}).Where(
		"user_id = ? OR mobile = ? OR user_name = ?",
		username, username, username,
	).First(&account)

	// 2. Handle lookup errors.
	if err := result.Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// Deliberately vague to prevent account enumeration.
			return nil, fmt.Errorf("用户名或密码错误")
		}
		return nil, fmt.Errorf("登录失败:%w", err)
	}

	// 3. Check the account status.
	// TODO

	// 4. Verify the password with scrypt.
	// 4.1 Decode the stored salt, which is read with base64.URLEncoding.
	salt, err := base64.URLEncoding.DecodeString(account.UserSalt)
	if err != nil {
		return nil, fmt.Errorf("密码验证失败")
	}

	// 4.2 Derive the hash of the supplied password with the same parameters.
	// NOTE(review): these parameters must match the ones used when the hash
	// was stored; that code is not visible here.
	const keyLen = 10
	inputHash, err := scrypt.Key([]byte(password), salt, 16384, 8, 1, keyLen)
	if err != nil {
		return nil, fmt.Errorf("密码验证失败")
	}

	// 4.3 Compare with the stored hash in constant time, so the comparison
	// does not leak how many leading characters matched.
	encoded := base64.StdEncoding.EncodeToString(inputHash)
	if subtle.ConstantTimeCompare([]byte(encoded), []byte(account.UserPassword)) != 1 {
		return nil, fmt.Errorf("用户名或密码错误")
	}

	// 5. Update the last-login information.
	// NOTE(review): a failure here is returned alongside the (valid) user,
	// as before; callers that check the error first treat it as a failed login.
	return &account, updateLastLoginInfo(account.ID, getClientIP())
}
// updateLastLoginInfo stores loginIP and the current time in the
// last_login_ip and last_login_time columns of the user row whose id column
// equals userID. It returns the error reported by the update, if any.
func updateLastLoginInfo(userID uint, loginIP string) error {
	loginAt := time.Now()
	fields := map[string]interface{}{
		"last_login_time": &loginAt,
		"last_login_ip":   loginIP,
	}
	return DB.Model(&model.User{}).Where("id = ?", userID).Updates(fields).Error
}
// getClientIP returns the client IP to record for a login.
// TODO: not implemented yet; it currently always returns the loopback address.
func getClientIP() string {
	const placeholderIP = "127.0.0.1"
	return placeholderIP
}