package service import ( "context" "fmt" "log/slog" "strconv" "time" "golang.org/x/sync/singleflight" ) // OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 // TokenRefresher 接口的超集:增加了 CacheKey 方法用于分布式锁 type OAuthRefreshExecutor interface { TokenRefresher // CacheKey 返回用于分布式锁的缓存键(与 TokenProvider 使用的一致) CacheKey(account *Account) string } const refreshLockTTL = 30 * time.Second // OAuthRefreshResult 统一刷新结果 type OAuthRefreshResult struct { Refreshed bool // 实际执行了刷新 NewCredentials map[string]any // 刷新后的 credentials(nil 表示未刷新) Account *Account // 从 DB 重新读取的最新 account LockHeld bool // 锁被其他 worker 持有(未执行刷新) } // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 // 封装分布式锁、DB 重读、已刷新检查等通用逻辑 // // 双层去重设计: // 1. 进程内 singleflight:合并同一 cacheKey 的并发调用(避免 100 个 goroutine // 都去抢同一把分布式锁、都重读一次 DB)。 // 2. 跨进程分布式锁(Redis):保证集群范围内只有一个 worker 真正发起 OAuth // 刷新请求。 // // 进程内去重在分布式锁之外做,避免无谓的 Redis RTT;跨进程锁仍是必需的, // singleflight 解决不了多 pod 同时刷新。 type OAuthRefreshAPI struct { accountRepo AccountRepository tokenCache GeminiTokenCache // 可选,nil = 无锁 sf singleflight.Group } // NewOAuthRefreshAPI 创建统一刷新 API func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { return &OAuthRefreshAPI{ accountRepo: accountRepo, tokenCache: tokenCache, } } // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。 // // 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者" // 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。 // // 流程: // 1. singleflight 合并同 cacheKey 并发调用 // 2. 获取分布式锁(跨进程) // 3. 从 DB 重读最新 account(防止使用过时的 refresh_token) // 4. 二次检查是否仍需刷新 // 5. 调用 executor.Refresh() 执行平台特定刷新逻辑 // 6. 设置 _token_version + 更新 DB // 7. 释放锁 func (api *OAuthRefreshAPI) RefreshIfNeeded( ctx context.Context, account *Account, executor OAuthRefreshExecutor, refreshWindow time.Duration, ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) // singleflight key 同时区分 cacheKey 和 refreshWindow: // 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh, // 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。 sfKey := cacheKey + "|" + refreshWindow.String() v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) { return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey) }) if err != nil { return nil, err } result, _ := v.(*OAuthRefreshResult) return result, nil } // refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。 // 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。 func (api *OAuthRefreshAPI) refreshOnce( ctx context.Context, account *Account, executor OAuthRefreshExecutor, refreshWindow time.Duration, cacheKey string, ) (*OAuthRefreshResult, error) { // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) if lockErr != nil { // Redis 错误,降级为无锁刷新 slog.Warn("oauth_refresh_lock_failed_degraded", "account_id", account.ID, "cache_key", cacheKey, "error", lockErr, ) } else if !acquired { // 锁被其他 worker 持有 return &OAuthRefreshResult{LockHeld: true}, nil } else { lockAcquired = true defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() } } // 2. 从 DB 重读最新 account(锁保护下,确保使用最新的 refresh_token) freshAccount, err := api.accountRepo.GetByID(ctx, account.ID) if err != nil { slog.Warn("oauth_refresh_db_reread_failed", "account_id", account.ID, "error", err, ) // 降级使用传入的 account freshAccount = account } else if freshAccount == nil { freshAccount = account } // 3. 二次检查是否仍需刷新(另一条路径可能已刷新) if !executor.NeedsRefresh(freshAccount, refreshWindow) { return &OAuthRefreshResult{ Account: freshAccount, }, nil } // 4. 执行平台特定刷新逻辑 newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) if refreshErr != nil { return nil, refreshErr } // 5. 设置版本号 + 更新 DB if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil { slog.Error("oauth_refresh_update_failed", "account_id", freshAccount.ID, "error", updateErr, ) return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr) } } _ = lockAcquired // suppress unused warning when tokenCache is nil return &OAuthRefreshResult{ Refreshed: true, NewCredentials: newCredentials, Account: freshAccount, }, nil } // MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { if newCreds == nil { newCreds = make(map[string]any) } for k, v := range oldCreds { if _, exists := newCreds[k]; !exists { newCreds[k] = v } } return newCreds } // BuildClaudeAccountCredentials 为 Claude 平台构建 OAuth credentials map // 消除 Claude 平台没有 BuildAccountCredentials 方法的问题 func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any { creds := map[string]any{ "access_token": tokenInfo.AccessToken, "token_type": tokenInfo.TokenType, "expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10), "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), } if tokenInfo.RefreshToken != "" { creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.Scope != "" { creds["scope"] = tokenInfo.Scope } return creds }