fix(openai): 永久禁用缺失 refresh_token 的 OAuth 账号

token_provider 在 expires_at 已过且 refresh_token 缺失时,仅返回 error,未做任何降级。
HandleUpstreamError 的 OAuth 401 分支也只走 10min 冷却,不区分账号是否具备刷新能力。
两条路径相加导致缺 refresh_token 的账号被反复选中、每次都在 token 阶段失败,对用户呈现持续 502。

token_provider.GetAccessToken: 命中"过期且无 refresh_token"时调用 SetError 永久禁用并清缓存,
依赖 background context 避免请求 ctx 提前结束影响落库。
ratelimit_service 401 OAuth 分支:refresh_token 为空时直接 SetError,不再写 expires_at、
不再 SetTempUnschedulable,但仍保留缓存失效逻辑。带 refresh_token(RT)的账号路径完全不动。

新增/调整测试覆盖两条路径,旧测试为 RT 路径补足 refresh_token 字段以保留原意图。
This commit is contained in:
name 2026-05-16 19:40:23 +08:00
parent 6e66edbb09
commit bec1e2b697
4 changed files with 135 additions and 2 deletions

View File

@ -154,7 +154,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
if expiresAt != nil && !time.Now().Before(*expiresAt) {
return "", errors.New("openai access_token expired and refresh_token is missing")
const reason = "openai access_token expired and refresh_token is missing"
// 永久故障:缺失 refresh_token 时账号无法自愈,必须立即从调度池剔除,
// 否则会被反复选中、每次都在 token 阶段直接返回错误,对用户呈现持续 502。
p.disableAccountMissingRefreshToken(account, reason)
return "", errors.New(reason)
}
needsRefresh = false
}
@ -261,6 +265,39 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
// disableAccountMissingRefreshToken marks an OpenAI OAuth account as errored
// when the request path finds that its credentials have expired and no
// refresh_token is available.
//
// This is a permanent failure: neither later requests nor TokenRefreshService
// can heal the account (NeedsRefresh also skips it because the refresh_token
// is empty), so it has to be removed from scheduling proactively; otherwise it
// keeps being selected and clients see repeated 502s.
//
// The call is a no-op when the provider, its account repository, or the
// account is nil. Failures are logged and never returned to the caller. If
// SetError fails, the cached access token is left untouched.
func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account, reason string) {
	if p == nil || p.accountRepo == nil || account == nil {
		return
	}

	// The request context may end very soon, so the write is detached from it.
	// A deadline is still applied so that a stalled repository or cache cannot
	// block the calling request goroutine indefinitely.
	const persistTimeout = 5 * time.Second
	bgCtx, cancel := context.WithTimeout(context.Background(), persistTimeout)
	defer cancel()

	if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
		slog.Warn("openai_token_provider.set_error_failed",
			"account_id", account.ID,
			"error", err,
		)
		return
	}

	// Drop any cached access token for this account; a failure here is only
	// logged because the account has already been marked as errored.
	if p.tokenCache != nil {
		cacheKey := OpenAITokenCacheKey(account)
		if err := p.tokenCache.DeleteAccessToken(bgCtx, cacheKey); err != nil {
			slog.Warn("openai_token_provider.cache_delete_failed",
				"account_id", account.ID,
				"error", err,
			)
		}
	}

	slog.Warn("openai_token_provider.account_disabled_missing_refresh_token",
		"account_id", account.ID,
		"reason", reason,
	)
}
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
wait := openAILockInitialWait
totalWaitMs := int64(0)

View File

@ -930,3 +930,34 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
}
func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) {
cache := newOpenAITokenCacheStub()
repo := &rateLimitAccountRepoStub{}
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
account := &Account{
ID: 2881,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "expired-access-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
cache.tokens[cacheKey] = "stale-cached-token"
// Force the provider past the cache hit branch.
cache.getErr = errors.New("simulated cache miss")
provider := NewOpenAITokenProvider(repo, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Empty(t, token)
require.Contains(t, err.Error(), "refresh_token is missing")
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
}

View File

@ -209,6 +209,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
// 缺少 refresh_token 的 OAuth 账号无法在冷却期内自愈(后台刷新服务也会跳过),
// 直接走 SetError 永久禁用,避免冷却结束后再被选中产生一发无意义的 502。
if strings.TrimSpace(account.GetCredential("refresh_token")) == "" {
msg := "Authentication failed (401): refresh_token missing, cannot recover"
if upstreamMsg != "" {
msg = "OAuth 401 (no refresh_token): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
break
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if account.Credentials == nil {
account.Credentials = make(map[string]any)

View File

@ -85,6 +85,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "rt-100",
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
@ -138,6 +139,9 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "rt-101",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
@ -175,7 +179,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
"access_token": "token",
"refresh_token": "rt-103",
},
}
@ -185,3 +190,52 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
// TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError
// checks that a 401 for an OAuth account without a usable refresh_token is
// persisted via SetError (permanent disable) instead of going through the
// 10-minute cooldown: nobody can refresh the account during the cooldown, and
// it would only be selected again afterwards to produce one more 502.
func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
	t.Run("openai_no_refresh_token", func(t *testing.T) {
		accountRepo := &rateLimitAccountRepoStub{}
		cacheInvalidator := &tokenCacheInvalidatorRecorder{}
		svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
		svc.SetTokenCacheInvalidator(cacheInvalidator)

		// Access-token-only credentials: the refresh_token key is absent.
		acct := &Account{
			ID:       2881,
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Credentials: map[string]any{
				"access_token": "expired-at",
			},
		}

		disabled := svc.HandleUpstreamError(
			context.Background(),
			acct,
			http.StatusUnauthorized,
			http.Header{},
			[]byte("unauthorized"),
		)

		require.True(t, disabled)
		require.Equal(t, 1, accountRepo.setErrorCalls, "AT-only OAuth 401 must SetError")
		require.Equal(t, 0, accountRepo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
		require.Equal(t, 0, accountRepo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
		require.Contains(t, accountRepo.lastErrorMsg, "refresh_token missing")
		require.Len(t, cacheInvalidator.accounts, 1, "cache should still be invalidated")
	})

	t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
		accountRepo := &rateLimitAccountRepoStub{}
		svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)

		// A whitespace-only refresh_token must be handled like a missing one.
		acct := &Account{
			ID:       2882,
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Credentials: map[string]any{
				"access_token":  "expired-at",
				"refresh_token": "   ",
			},
		}

		disabled := svc.HandleUpstreamError(
			context.Background(),
			acct,
			http.StatusUnauthorized,
			http.Header{},
			[]byte("unauthorized"),
		)

		require.True(t, disabled)
		require.Equal(t, 1, accountRepo.setErrorCalls)
		require.Equal(t, 0, accountRepo.tempCalls)
	})
}