sub2api/backend/internal/service/oauth_refresh_api.go
win 5c8c15cdb1 feat(refresh,repo): add singleflight to dedupe concurrent token refresh and unschedulable writes
Two anti-thundering-herd improvements:

1. OAuthRefreshAPI.RefreshIfNeeded
   Wrap the existing distributed-lock + DB-reread + executor.Refresh
   pipeline in a per-process singleflight keyed by cacheKey+window.
   Without this, N concurrent goroutines on the same account each pay
   one Redis lock RTT and one DB reread; with it, only the leader pays
   and the rest share the result.

   The refreshWindow is part of the key so a long background-refresh
   window cannot starve a short foreground-refresh window.

2. accountRepository.SetTempUnschedulable
   Wrap the same path (UPDATE + scheduler outbox enqueue + scheduler
   cache sync) in a per-process singleflight keyed by id+until+reason.
   The SQL guard (existing < new) already makes the UPDATE idempotent,
   but N callers still cost N round-trips and N outbox inserts. With
   singleflight, an upstream 401 burst that hits the same account
   collapses to one execution.

Tests cover dedup behavior, key separation by account / refresh window,
and that the SQL exec count drops from N to <=2 (UPDATE + outbox).
2026-04-29 00:43:23 +08:00

199 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"log/slog"
"strconv"
"time"
"golang.org/x/sync/singleflight"
)
// OAuthRefreshExecutor is the OAuth refresh executor implemented by each platform.
// It is a superset of TokenRefresher: it adds CacheKey, whose result is used as
// the key of the distributed refresh lock.
type OAuthRefreshExecutor interface {
TokenRefresher
// CacheKey returns the cache key used for the distributed lock (the same key TokenProvider uses).
CacheKey(account *Account) string
}
// refreshLockTTL is the TTL passed to AcquireRefreshLock when taking the
// distributed refresh lock.
const refreshLockTTL = 30 * time.Second
// OAuthRefreshResult is the unified outcome of a refresh attempt.
type OAuthRefreshResult struct {
Refreshed bool // true when a refresh was actually executed
NewCredentials map[string]any // credentials produced by the refresh (nil means no refresh happened)
Account *Account // latest account re-read from the DB; left nil when LockHeld is reported
LockHeld bool // the lock is held by another worker (no refresh was executed)
}
// OAuthRefreshAPI is the unified entry point for OAuth token refresh.
// It wraps the shared logic: distributed lock, DB re-read, and the
// "already refreshed" re-check.
//
// Two-layer de-duplication:
//  1. In-process singleflight merges concurrent calls for the same cacheKey
//     and refresh window (so 100 goroutines do not all contend for the same
//     distributed lock and each re-read the DB).
//  2. A cross-process distributed lock (Redis) ensures only one worker in the
//     cluster actually issues the OAuth refresh request.
//
// The in-process layer sits outside the distributed lock to avoid pointless
// Redis round-trips; the cross-process lock is still required because
// singleflight cannot stop several pods from refreshing at the same time.
type OAuthRefreshAPI struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache // optional; nil = run without the distributed lock
sf singleflight.Group // zero value is used as-is; not set by the constructor
}
// NewOAuthRefreshAPI builds the unified refresh API around the given account
// repository and token cache. tokenCache may be nil, in which case refreshes
// run without the distributed lock. The singleflight group is left at its
// zero value.
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
	api := new(OAuthRefreshAPI)
	api.accountRepo = accountRepo
	api.tokenCache = tokenCache
	return api
}
// RefreshIfNeeded refreshes the account's OAuth token on demand, under the
// protection of the distributed lock.
//
// Concurrent calls in the same process that share both the executor's cache
// key and refreshWindow are merged by singleflight: only the "leader" call
// runs refreshOnce, and every other caller receives the same
// *OAuthRefreshResult / error. Because the result pointer is shared between
// callers, it should be treated as read-only.
//
// Flow:
//  1. singleflight merges concurrent calls with the same key
//  2. acquire the distributed lock (cross-process)
//  3. re-read the latest account from the DB (avoid a stale refresh_token)
//  4. re-check whether a refresh is still needed
//  5. call executor.Refresh() for the platform-specific refresh
//  6. set _token_version and update the DB
//  7. release the lock
//
// Cancellation: the shared work runs on a context that keeps the leader's
// values and deadline but ignores the leader's explicit cancellation, so one
// caller going away does not fail every other caller waiting on the same key.
//
// Returns the shared result, or the error produced by refreshOnce.
func (api *OAuthRefreshAPI) RefreshIfNeeded(
	ctx context.Context,
	account *Account,
	executor OAuthRefreshExecutor,
	refreshWindow time.Duration,
) (*OAuthRefreshResult, error) {
	cacheKey := executor.CacheKey(account)

	// The singleflight key separates both cacheKey and refreshWindow: a short
	// foreground window and a long background window must evaluate
	// NeedsRefresh independently, otherwise one could be handed the other's
	// "no refresh needed" verdict.
	sfKey := cacheKey + "|" + refreshWindow.String()

	v, err, _ := api.sf.Do(sfKey, func() (any, error) {
		// Detach from the leader's cancellation signal; followers share this
		// execution and must not fail just because the leader's request was
		// cancelled. The leader's deadline (if any) is preserved so the work
		// stays bounded exactly as before.
		sharedCtx := context.WithoutCancel(ctx)
		if deadline, ok := ctx.Deadline(); ok {
			var cancel context.CancelFunc
			sharedCtx, cancel = context.WithDeadline(sharedCtx, deadline)
			defer cancel()
		}
		return api.refreshOnce(sharedCtx, account, executor, refreshWindow, cacheKey)
	})
	if err != nil {
		return nil, err
	}

	result, ok := v.(*OAuthRefreshResult)
	if !ok {
		// Never hand callers a (nil, nil) pair.
		return nil, fmt.Errorf("oauth refresh: unexpected singleflight result type %T", v)
	}
	return result, nil
}
// refreshOnce is the worker behind RefreshIfNeeded; it is meant to be invoked
// only by the singleflight leader. It is split out so the lock / re-read /
// refresh sequence can be unit-tested directly and so the deferred lock
// release does not have to live inside the sf.Do closure.
//
// Parameters:
//   - ctx: context for the lock, DB and refresh calls
//   - account: the caller's (possibly stale) view of the account
//   - executor: platform-specific refresher
//   - refreshWindow: window handed to executor.NeedsRefresh
//   - cacheKey: key of the distributed lock
//
// Returns:
//   - {LockHeld: true} when another worker holds the lock
//   - {Account: fresh} when the re-check says no refresh is needed
//   - {Refreshed: true, NewCredentials, Account} after a successful refresh
//   - an error when executor.Refresh fails or persisting credentials fails
func (api *OAuthRefreshAPI) refreshOnce(
	ctx context.Context,
	account *Account,
	executor OAuthRefreshExecutor,
	refreshWindow time.Duration,
	cacheKey string,
) (*OAuthRefreshResult, error) {
	// Upper bound for the deferred lock release, which runs on a context
	// detached from ctx.
	const lockReleaseTimeout = 5 * time.Second

	// 1. Acquire the distributed lock (skipped when no token cache is configured).
	if api.tokenCache != nil {
		acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
		switch {
		case lockErr != nil:
			// Lock backend error: degrade to refreshing without the lock.
			slog.Warn("oauth_refresh_lock_failed_degraded",
				"account_id", account.ID,
				"cache_key", cacheKey,
				"error", lockErr,
			)
		case !acquired:
			// The lock is held by another worker.
			return &OAuthRefreshResult{LockHeld: true}, nil
		default:
			defer func() {
				// Release on a context that survives cancellation/expiry of
				// ctx; otherwise a cancelled request would leave the lock in
				// place until refreshLockTTL runs out.
				releaseCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), lockReleaseTimeout)
				defer cancel()
				if releaseErr := api.tokenCache.ReleaseRefreshLock(releaseCtx, cacheKey); releaseErr != nil {
					slog.Warn("oauth_refresh_lock_release_failed",
						"account_id", account.ID,
						"cache_key", cacheKey,
						"error", releaseErr,
					)
				}
			}()
		}
	}

	// 2. Re-read the latest account from the DB so the refresh uses the
	// newest refresh_token.
	freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
	if err != nil {
		slog.Warn("oauth_refresh_db_reread_failed",
			"account_id", account.ID,
			"error", err,
		)
	}
	if err != nil || freshAccount == nil {
		// Degrade to the account the caller passed in.
		freshAccount = account
	}

	// 3. Re-check whether a refresh is still needed (another path may have
	// refreshed already).
	if !executor.NeedsRefresh(freshAccount, refreshWindow) {
		return &OAuthRefreshResult{Account: freshAccount}, nil
	}

	// 4. Run the platform-specific refresh.
	newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
	if refreshErr != nil {
		return nil, refreshErr
	}

	// 5. Stamp the version and persist to the DB.
	if newCredentials != nil {
		newCredentials["_token_version"] = time.Now().UnixMilli()
		if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil {
			slog.Error("oauth_refresh_update_failed",
				"account_id", freshAccount.ID,
				"error", updateErr,
			)
			return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
		}
	}

	return &OAuthRefreshResult{
		Refreshed:      true,
		NewCredentials: newCredentials,
		Account:        freshAccount,
	}, nil
}
// MergeCredentials carries over into newCreds every field of oldCreds whose
// key newCreds does not already contain, and returns the merged map.
//
// newCreds is modified in place; when it is nil a fresh map is allocated.
// Keys already present in newCreds win, even if their value is nil.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
	merged := newCreds
	if merged == nil {
		merged = map[string]any{}
	}
	for key, oldValue := range oldCreds {
		if _, present := merged[key]; present {
			continue
		}
		merged[key] = oldValue
	}
	return merged
}
// BuildClaudeAccountCredentials builds the OAuth credentials map for the
// Claude platform, which has no BuildAccountCredentials method of its own.
//
// access_token and token_type are copied as-is; expires_in and expires_at are
// stored as base-10 strings. refresh_token and scope are only included when
// non-empty.
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
	credentials := make(map[string]any)
	credentials["access_token"] = tokenInfo.AccessToken
	credentials["token_type"] = tokenInfo.TokenType
	credentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
	credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)

	if refreshToken := tokenInfo.RefreshToken; refreshToken != "" {
		credentials["refresh_token"] = refreshToken
	}
	if scope := tokenInfo.Scope; scope != "" {
		credentials["scope"] = scope
	}
	return credentials
}