// Source file: sub2api/backend/internal/service/oauth_refresh_api.go
// (260 lines, 8.6 KiB, Go).
//
// Note: the file viewer flagged this file as containing ambiguous Unicode
// characters, i.e. characters that might be confused with other characters
// (the comments use Chinese text with full-width punctuation).

package service
import (
"context"
"fmt"
"log/slog"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
// OAuthRefreshExecutor is the OAuth refresh executor implemented by each platform.
//
// It is a superset of the TokenRefresher interface: it adds CacheKey, which
// supplies the key used for the distributed refresh lock. OAuthRefreshAPI calls
// CacheKey plus the embedded TokenRefresher's NeedsRefresh and Refresh methods.
type OAuthRefreshExecutor interface {
	TokenRefresher
	// CacheKey returns the cache key used for the distributed lock (the same key
	// the TokenProvider uses).
	CacheKey(account *Account) string
}
// defaultRefreshLockTTL is the distributed refresh-lock TTL used when
// NewOAuthRefreshAPI is not given a positive override.
const defaultRefreshLockTTL time.Duration = time.Minute
// OAuthRefreshResult is the unified result of a refresh attempt.
//
// Shapes produced by OAuthRefreshAPI in this file:
//   - LockHeld=true: another worker holds the distributed lock; no refresh was
//     attempted and Account is nil.
//   - Refreshed=false with Account set: after re-reading, no refresh was needed,
//     or an invalid_grant race was recovered from.
//   - Refreshed=true: a refresh was performed; NewCredentials holds what the
//     executor returned (it can still be nil if the executor returned nil).
type OAuthRefreshResult struct {
	Refreshed      bool           // a refresh was actually performed
	NewCredentials map[string]any // credentials after refresh (nil when no refresh was performed)
	Account        *Account       // latest account re-read from the DB (falls back to the caller's account if the re-read failed)
	LockHeld       bool           // the lock is held by another worker (no refresh was performed)
}
// OAuthRefreshAPI is the unified entry point for OAuth token refresh.
//
// It wraps the shared logic: distributed lock, in-process de-duplication
// (singleflight), DB re-read, already-refreshed check, and race recovery.
//
// Two layers of de-duplication:
//  1. In-process singleflight: concurrent calls for the same cacheKey (and
//     refresh window) are merged, so e.g. 100 goroutines do not all contend
//     for the same distributed lock and each re-read the DB.
//  2. Cross-process distributed lock (Redis): ensures only one worker in the
//     cluster actually sends the OAuth refresh request.
//
// The in-process de-duplication happens outside the distributed lock to avoid
// needless Redis round trips. The cross-process lock is still required because
// singleflight cannot prevent several pods from refreshing at the same time.
//
// The struct embeds sync.Map and singleflight.Group, so it must not be copied;
// use it via *OAuthRefreshAPI.
type OAuthRefreshAPI struct {
	accountRepo AccountRepository
	tokenCache  GeminiTokenCache // optional; nil = no distributed lock
	lockTTL     time.Duration    // TTL passed to AcquireRefreshLock
	localLocks  sync.Map         // key: cacheKey string -> value: *sync.Mutex (handed out by getLocalLock)
	sf          singleflight.Group
}
// NewOAuthRefreshAPI builds the unified refresh API.
//
// tokenCache may be nil, in which case no distributed lock is used. The
// optional lockTTL overrides the default distributed-lock TTL
// (defaultRefreshLockTTL). Only the first value is considered, and
// non-positive values are ignored.
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI {
	api := &OAuthRefreshAPI{
		accountRepo: accountRepo,
		tokenCache:  tokenCache,
		lockTTL:     defaultRefreshLockTTL,
	}
	if len(lockTTL) != 0 && lockTTL[0] > 0 {
		api.lockTTL = lockTTL[0]
	}
	return api
}
// getLocalLock returns the in-process mutex for cacheKey, creating it on first
// use. Callers asking for the same cacheKey normally receive the same
// *sync.Mutex.
func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
	// Fast path: the mutex usually exists already. Checking with Load first
	// avoids allocating a throwaway mutex on every call, because the
	// LoadOrStore argument is evaluated and heap-allocated even when the key
	// is present.
	if v, ok := api.localLocks.Load(cacheKey); ok {
		if mu, ok := v.(*sync.Mutex); ok {
			return mu
		}
	}

	actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{})
	mu, ok := actual.(*sync.Mutex)
	if !ok {
		// Defensive: this method only ever stores *sync.Mutex, so a foreign
		// value here is unexpected; replace it with a fresh mutex.
		mu = &sync.Mutex{}
		api.localLocks.Store(cacheKey, mu)
	}
	return mu
}
// RefreshIfNeeded refreshes the OAuth token on demand under the distributed lock.
//
// Concurrent calls in this process with the same cacheKey and refreshWindow are
// coalesced by singleflight. Only the "leader" call runs refreshOnce, and every
// coalesced caller receives the same *OAuthRefreshResult / error. Callers
// should treat the shared result as read-only.
//
// The shared work runs on a context detached from the leader caller's
// cancellation (context values are preserved) and is bounded by the lock TTL.
// Without this, a cancelled leader request would fail every coalesced caller
// with an error none of them caused. It could also abort the flow after the
// provider issued new tokens but before they were persisted. As a consequence,
// a caller whose ctx is cancelled still waits for the in-flight refresh, for at
// most the lock TTL.
//
// Flow:
//  1. singleflight coalesces concurrent calls with the same key
//  2. acquire the distributed (cross-process) lock
//  3. re-read the latest account from the DB (avoid a stale refresh_token)
//  4. re-check whether a refresh is still needed
//  5. call executor.Refresh() for the platform-specific refresh
//  6. set _token_version and update the DB
//  7. release the lock
func (api *OAuthRefreshAPI) RefreshIfNeeded(
	ctx context.Context,
	account *Account,
	executor OAuthRefreshExecutor,
	refreshWindow time.Duration,
) (*OAuthRefreshResult, error) {
	cacheKey := executor.CacheKey(account)

	// The singleflight key covers both cacheKey and refreshWindow: different
	// refresh windows (short foreground vs long background) must evaluate
	// NeedsRefresh separately. Otherwise an in-flight long-window background
	// call would let a short-window foreground caller believe the token was
	// refreshed and immediately return the old value.
	sfKey := cacheKey + "|" + refreshWindow.String()

	v, err, _ := api.sf.Do(sfKey, func() (any, error) {
		// Bound the detached work by the lock TTL: past that point the
		// distributed lock may have expired, so the refresh would no longer be
		// exclusive anyway. Guard against a zero-value struct (lockTTL unset).
		budget := api.lockTTL
		if budget <= 0 {
			budget = defaultRefreshLockTTL
		}
		workCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), budget)
		defer cancel()
		return api.refreshOnce(workCtx, account, executor, refreshWindow, cacheKey)
	})
	if err != nil {
		return nil, err
	}
	result, _ := v.(*OAuthRefreshResult)
	return result, nil
}
// refreshOnce does the actual work behind RefreshIfNeeded and is only invoked by
// the singleflight leader. It is split out so the lock / re-read / refresh
// logic can be unit-tested directly, and so the sf.Do closure does not have to
// manage several defers.
//
// Serialization:
//   - tokenCache set and AcquireRefreshLock succeeds: the distributed lock
//     guards the refresh across processes and is released on return.
//   - Another worker holds that lock: returns {LockHeld: true} without
//     refreshing.
//   - No distributed lock (tokenCache is nil, or acquiring it returned an
//     error): the per-cacheKey mutex from getLocalLock serializes refreshes in
//     this process. singleflight alone is not enough here because its key also
//     includes refreshWindow, so calls with different windows for the same
//     cacheKey would otherwise refresh concurrently.
//
// Results:
//   - {LockHeld: true} when another worker holds the distributed lock.
//   - {Account: fresh} when no refresh is needed after the DB re-read.
//   - {Account: recovered} when Refresh failed with invalid_grant but the DB now
//     holds a different refresh_token (another worker won the race).
//   - {Refreshed: true, NewCredentials, Account} on success.
//   - An error when Refresh fails without race recovery, or when the DB update
//     after a successful refresh fails.
func (api *OAuthRefreshAPI) refreshOnce(
	ctx context.Context,
	account *Account,
	executor OAuthRefreshExecutor,
	refreshWindow time.Duration,
	cacheKey string,
) (*OAuthRefreshResult, error) {
	// Upper bounds for cleanup-critical steps that run on a context detached
	// from ctx's cancellation (see their call sites).
	const (
		lockReleaseTimeout = 5 * time.Second
		persistTimeout     = 30 * time.Second
	)

	// 1. Serialize the refresh.
	distributedLockHeld := false
	if api.tokenCache != nil {
		acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL)
		switch {
		case lockErr != nil:
			// Redis error: degrade to refreshing without the distributed lock.
			// The in-process mutex below still serializes this process.
			slog.Warn("oauth_refresh_lock_failed_degraded",
				"account_id", account.ID,
				"cache_key", cacheKey,
				"error", lockErr,
			)
		case !acquired:
			// The lock is held by another worker.
			return &OAuthRefreshResult{LockHeld: true}, nil
		default:
			distributedLockHeld = true
			defer func() {
				// Release on a context detached from ctx's cancellation. If ctx
				// was cancelled mid-refresh, releasing with it would fail and
				// the lock would stay held until its TTL expires, stalling every
				// other worker for this account.
				releaseCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), lockReleaseTimeout)
				defer cancel()
				if err := api.tokenCache.ReleaseRefreshLock(releaseCtx, cacheKey); err != nil {
					slog.Warn("oauth_refresh_lock_release_failed",
						"account_id", account.ID,
						"cache_key", cacheKey,
						"error", err,
					)
				}
			}()
		}
	}
	if !distributedLockHeld {
		mu := api.getLocalLock(cacheKey)
		mu.Lock()
		defer mu.Unlock()
	}

	// 2. Re-read the account from the DB while serialized, so the latest
	//    refresh_token is used.
	freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
	if err != nil {
		slog.Warn("oauth_refresh_db_reread_failed",
			"account_id", account.ID,
			"error", err,
		)
	}
	if err != nil || freshAccount == nil {
		// Degrade to the account the caller passed in.
		freshAccount = account
	}

	// 3. Re-check whether a refresh is still needed (another path may already
	//    have refreshed).
	if !executor.NeedsRefresh(freshAccount, refreshWindow) {
		return &OAuthRefreshResult{
			Account: freshAccount,
		}, nil
	}

	// 4. Run the platform-specific refresh.
	newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
	if refreshErr != nil {
		// Race recovery: invalid_grant may mean another worker already consumed
		// the old refresh_token. Re-read the DB; if the refresh_token changed,
		// it was a race and we report success with the re-read account.
		if isInvalidGrantError(refreshErr) {
			if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered {
				slog.Info("oauth_refresh_race_recovered",
					"account_id", freshAccount.ID,
					"platform", freshAccount.Platform,
				)
				return &OAuthRefreshResult{
					Account: recoveredAccount,
				}, nil
			}
		}
		return nil, refreshErr
	}

	// 5. Stamp the token version and persist to the DB.
	if newCredentials != nil {
		newCredentials["_token_version"] = time.Now().UnixMilli()
		// The refresh already succeeded upstream. If the provider rotates
		// refresh tokens (NOTE(review): presumably true for some platforms),
		// losing this write would leave the stored refresh_token invalid.
		// Persist on a context detached from ctx's cancellation, bounded by
		// persistTimeout.
		persistCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), persistTimeout)
		updateErr := persistAccountCredentials(persistCtx, api.accountRepo, freshAccount, newCredentials)
		cancel()
		if updateErr != nil {
			slog.Error("oauth_refresh_update_failed",
				"account_id", freshAccount.ID,
				"error", updateErr,
			)
			return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
		}
	}

	return &OAuthRefreshResult{
		Refreshed:      true,
		NewCredentials: newCredentials,
		Account:        freshAccount,
	}, nil
}
// isInvalidGrantError reports whether err's message contains "invalid_grant",
// compared case-insensitively. A nil error never matches.
//
// Matching is done on the error text rather than on a typed error value.
func isInvalidGrantError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	return strings.Contains(msg, "invalid_grant")
}
// tryRecoverFromRefreshRace attempts race recovery after an invalid_grant error.
//
// It re-reads the account from the DB. If the stored refresh_token differs from
// the one that was used (meaning another worker refreshed successfully), it
// returns the re-read account and true. It returns (nil, false) when accountRepo
// is nil, the re-read fails or yields nil, either refresh_token is empty, or the
// two tokens are equal.
func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) {
	if api.accountRepo == nil {
		return nil, false
	}

	latest, err := api.accountRepo.GetByID(ctx, usedAccount.ID)
	if err != nil || latest == nil {
		return nil, false
	}

	staleToken := usedAccount.GetCredential("refresh_token")
	latestToken := latest.GetCredential("refresh_token")

	// Both tokens must be known, and a different token means another worker
	// already completed the refresh.
	rotated := staleToken != "" && latestToken != "" && staleToken != latestToken
	if !rotated {
		return nil, false
	}
	return latest, true
}
// MergeCredentials copies into newCreds every key of oldCreds that newCreds does
// not already contain, and returns the merged map.
//
// Keys already present in newCreds win, even when their value is nil. newCreds
// is modified in place and returned; a fresh map is allocated only when
// newCreds is nil.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
	merged := newCreds
	if merged == nil {
		merged = make(map[string]any, len(oldCreds))
	}
	for key, oldVal := range oldCreds {
		if _, present := merged[key]; present {
			continue
		}
		merged[key] = oldVal
	}
	return merged
}
// BuildClaudeAccountCredentials builds the OAuth credentials map for the Claude
// platform, which has no BuildAccountCredentials method of its own.
//
// access_token, token_type, expires_in and expires_at are always set, with the
// two expiry values rendered as base-10 strings. refresh_token and scope are
// added only when non-empty. tokenInfo must be non-nil (it is dereferenced
// unconditionally).
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
	expiresIn := strconv.FormatInt(tokenInfo.ExpiresIn, 10)
	expiresAt := strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	creds := make(map[string]any, 6)
	creds["access_token"] = tokenInfo.AccessToken
	creds["token_type"] = tokenInfo.TokenType
	creds["expires_in"] = expiresIn
	creds["expires_at"] = expiresAt

	if rt := tokenInfo.RefreshToken; rt != "" {
		creds["refresh_token"] = rt
	}
	if scope := tokenInfo.Scope; scope != "" {
		creds["scope"] = scope
	}
	return creds
}