132 lines
4.1 KiB
Go
132 lines
4.1 KiB
Go
package redis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
)
|
|
|
|
// TokenStore 处理与令牌存储相关的Redis操作
|
|
type TokenStore struct {
|
|
redisClient *redis.Client
|
|
tokenTTL time.Duration // 令牌有效期
|
|
}
|
|
|
|
// NewTokenStore 创建一个新的TokenStore实例
|
|
func NewTokenStore(redisClient *redis.Client, tokenTTL time.Duration) *TokenStore {
|
|
return &TokenStore{
|
|
redisClient: redisClient,
|
|
tokenTTL: tokenTTL,
|
|
}
|
|
}
|
|
|
|
// getUserTokensKey 生成存储用户令牌的Redis键名
|
|
func (ts *TokenStore) getUserTokensKey(username string) string {
|
|
return fmt.Sprintf("user:tokens:%s", username)
|
|
}
|
|
|
|
// getTokenBlacklistKey 生成令牌黑名单的Redis键名
|
|
func (ts *TokenStore) getTokenBlacklistKey(token string) string {
|
|
return fmt.Sprintf("blacklist:%s", token)
|
|
}
|
|
|
|
// InvalidateOldTokens 使指定用户的所有旧令牌失效
|
|
func (ts *TokenStore) InvalidateOldTokens(ctx context.Context, username string) error {
|
|
// 获取用户的所有旧令牌
|
|
oldTokens, err := ts.redisClient.SMembers(ctx, ts.getUserTokensKey(username)).Result()
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
return fmt.Errorf("获取旧令牌失败: %w", err)
|
|
}
|
|
|
|
// 将旧令牌加入黑名单
|
|
for _, oldToken := range oldTokens {
|
|
err := ts.redisClient.Set(ctx, ts.getTokenBlacklistKey(oldToken), "1", ts.tokenTTL).Err()
|
|
if err != nil {
|
|
// 这里可以选择记录日志但不返回错误,避免单个令牌处理失败影响整体流程
|
|
continue
|
|
}
|
|
}
|
|
|
|
// 清除用户的旧令牌列表
|
|
if len(oldTokens) > 0 {
|
|
if err := ts.redisClient.Del(ctx, ts.getUserTokensKey(username)).Err(); err != nil {
|
|
return fmt.Errorf("清除旧令牌列表失败: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// StoreNewToken 存储新令牌
|
|
func (ts *TokenStore) StoreNewToken(ctx context.Context, username, token string) error {
|
|
// 将新令牌添加到用户的令牌列表
|
|
if err := ts.redisClient.SAdd(ctx, ts.getUserTokensKey(username), token).Err(); err != nil {
|
|
return fmt.Errorf("存储新令牌失败: %w", err)
|
|
}
|
|
|
|
// 设置令牌列表的过期时间
|
|
if err := ts.redisClient.Expire(ctx, ts.getUserTokensKey(username), ts.tokenTTL).Err(); err != nil {
|
|
return fmt.Errorf("设置令牌过期时间失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsTokenValid 检查令牌是否有效
|
|
func (ts *TokenStore) IsTokenValid(ctx context.Context, username, token string) (bool, error) {
|
|
// 检查令牌是否在黑名单中
|
|
isBlacklisted, err := ts.redisClient.Exists(ctx, ts.getTokenBlacklistKey(token)).Result()
|
|
if err != nil {
|
|
return false, fmt.Errorf("检查令牌黑名单失败: %w", err)
|
|
}
|
|
if isBlacklisted > 0 {
|
|
return false, nil
|
|
}
|
|
|
|
// 检查令牌是否在用户的有效令牌列表中
|
|
exists, err := ts.redisClient.SIsMember(ctx, ts.getUserTokensKey(username), token).Result()
|
|
if err != nil {
|
|
return false, fmt.Errorf("检查令牌有效性失败: %w", err)
|
|
}
|
|
|
|
return exists, nil
|
|
}
|
|
|
|
// InvalidateToken 使指定令牌失效
|
|
func (ts *TokenStore) InvalidateToken(ctx context.Context, username, token string) error {
|
|
// 将令牌加入黑名单
|
|
if err := ts.redisClient.Set(ctx, ts.getTokenBlacklistKey(token), "1", ts.tokenTTL).Err(); err != nil {
|
|
return fmt.Errorf("将令牌加入黑名单失败: %w", err)
|
|
}
|
|
|
|
// 从用户的有效令牌列表中移除
|
|
if err := ts.redisClient.SRem(ctx, ts.getUserTokensKey(username), token).Err(); err != nil {
|
|
return fmt.Errorf("从有效令牌列表移除令牌失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ReplaceToken 替换令牌(旧令牌失效,新令牌生效)
|
|
func (ts *TokenStore) ReplaceToken(ctx context.Context, username, oldToken, newToken string) error {
|
|
userTokensKey := ts.getUserTokensKey(username)
|
|
|
|
pipe := ts.redisClient.Pipeline()
|
|
// 1. 旧令牌加入黑名单
|
|
pipe.Set(ctx, ts.getTokenBlacklistKey(oldToken), "1", ts.tokenTTL)
|
|
// 2. 移除旧令牌
|
|
pipe.SRem(ctx, userTokensKey, oldToken)
|
|
// 3. 添加新令牌
|
|
pipe.SAdd(ctx, userTokensKey, newToken)
|
|
// 4. 刷新过期时间
|
|
pipe.Expire(ctx, userTokensKey, ts.tokenTTL)
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("替换令牌失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|