fix(openai): 永久禁用缺失 refresh_token 的 OAuth 账号
token_provider 在 expires_at 已过且 refresh_token 缺失时，仅返回 error，未做任何降级。
HandleUpstreamError 的 OAuth 401 分支也只走 10 分钟冷却，不区分账号是否具备刷新能力。
两条路径叠加，导致缺少 refresh_token 的账号被反复选中，每次都在 token 阶段失败，对用户呈现持续的 502。
- token_provider.GetAccessToken：命中“已过期且无 refresh_token”时调用 SetError 永久禁用账号并清除缓存；使用 background context，避免请求 ctx 提前结束而影响落库。
- ratelimit_service 的 OAuth 401 分支：refresh_token 为空时直接 SetError，不再写 expires_at，也不再调用 SetTempUnschedulable，但保留缓存失效逻辑。带 refresh_token（RT）的账号路径完全不动。
- 新增/调整测试以覆盖上述两条路径；旧测试为 RT 路径补足 refresh_token 字段，以保留其原有意图。
This commit is contained in:
parent
6e66edbb09
commit
bec1e2b697
@ -154,7 +154,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||||
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
|
||||||
if expiresAt != nil && !time.Now().Before(*expiresAt) {
|
if expiresAt != nil && !time.Now().Before(*expiresAt) {
|
||||||
return "", errors.New("openai access_token expired and refresh_token is missing")
|
const reason = "openai access_token expired and refresh_token is missing"
|
||||||
|
// 永久故障:缺失 refresh_token 时账号无法自愈,必须立即从调度池剔除,
|
||||||
|
// 否则会被反复选中、每次都在 token 阶段直接返回错误,对用户呈现持续 502。
|
||||||
|
p.disableAccountMissingRefreshToken(account, reason)
|
||||||
|
return "", errors.New(reason)
|
||||||
}
|
}
|
||||||
needsRefresh = false
|
needsRefresh = false
|
||||||
}
|
}
|
||||||
@ -261,6 +265,39 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// disableAccountMissingRefreshToken 在请求路径上发现 OpenAI OAuth 账号
|
||||||
|
// 凭证已过期且 refresh_token 缺失时,将账号标记为 error 状态。
|
||||||
|
// 这是一种永久性故障:仅靠后续请求或 TokenRefreshService 不会自愈
|
||||||
|
// (NeedsRefresh 也会因 refresh_token 为空直接跳过),
|
||||||
|
// 必须主动剔除以避免账号被持续选中导致用户端反复 502。
|
||||||
|
// 使用 background context 是因为请求 context 可能很快结束。
|
||||||
|
func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account, reason string) {
|
||||||
|
if p == nil || p.accountRepo == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bgCtx := context.Background()
|
||||||
|
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
|
||||||
|
slog.Warn("openai_token_provider.set_error_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.tokenCache != nil {
|
||||||
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
|
if err := p.tokenCache.DeleteAccessToken(bgCtx, cacheKey); err != nil {
|
||||||
|
slog.Warn("openai_token_provider.cache_delete_failed",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"error", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slog.Warn("openai_token_provider.account_disabled_missing_refresh_token",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"reason", reason,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
|
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
|
||||||
wait := openAILockInitialWait
|
wait := openAILockInitialWait
|
||||||
totalWaitMs := int64(0)
|
totalWaitMs := int64(0)
|
||||||
|
|||||||
@ -930,3 +930,34 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
|||||||
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
|
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
|
||||||
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) {
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 2881,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "expired-access-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
|
cache.tokens[cacheKey] = "stale-cached-token"
|
||||||
|
// Force the provider past the cache hit branch.
|
||||||
|
cache.getErr = errors.New("simulated cache miss")
|
||||||
|
|
||||||
|
provider := NewOpenAITokenProvider(repo, cache, nil)
|
||||||
|
|
||||||
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, token)
|
||||||
|
require.Contains(t, err.Error(), "refresh_token is missing")
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
|
||||||
|
}
|
||||||
|
|||||||
@ -209,6 +209,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 缺少 refresh_token 的 OAuth 账号无法在冷却期内自愈(后台刷新服务也会跳过),
|
||||||
|
// 直接走 SetError 永久禁用,避免冷却结束后再被选中产生一发无意义的 502。
|
||||||
|
if strings.TrimSpace(account.GetCredential("refresh_token")) == "" {
|
||||||
|
msg := "Authentication failed (401): refresh_token missing, cannot recover"
|
||||||
|
if upstreamMsg != "" {
|
||||||
|
msg = "OAuth 401 (no refresh_token): " + upstreamMsg
|
||||||
|
}
|
||||||
|
s.handleAuthError(ctx, account, msg)
|
||||||
|
shouldDisable = true
|
||||||
|
break
|
||||||
|
}
|
||||||
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
||||||
if account.Credentials == nil {
|
if account.Credentials == nil {
|
||||||
account.Credentials = make(map[string]any)
|
account.Credentials = make(map[string]any)
|
||||||
|
|||||||
@ -85,6 +85,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
|
|||||||
Platform: PlatformGemini,
|
Platform: PlatformGemini,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
|
"refresh_token": "rt-100",
|
||||||
"temp_unschedulable_enabled": true,
|
"temp_unschedulable_enabled": true,
|
||||||
"temp_unschedulable_rules": []any{
|
"temp_unschedulable_rules": []any{
|
||||||
map[string]any{
|
map[string]any{
|
||||||
@ -138,6 +139,9 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
|
|||||||
ID: 101,
|
ID: 101,
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"refresh_token": "rt-101",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
@ -175,7 +179,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
|
|||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Credentials: map[string]any{
|
Credentials: map[string]any{
|
||||||
"access_token": "token",
|
"access_token": "token",
|
||||||
|
"refresh_token": "rt-103",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,3 +190,52 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
|
|||||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||||
require.NotEmpty(t, repo.lastCredentials["expires_at"])
|
require.NotEmpty(t, repo.lastCredentials["expires_at"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用,
|
||||||
|
// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。
|
||||||
|
func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
|
||||||
|
t.Run("openai_no_refresh_token", func(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
service.SetTokenCacheInvalidator(invalidator)
|
||||||
|
account := &Account{
|
||||||
|
ID: 2881,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "expired-at",
|
||||||
|
// no refresh_token
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError")
|
||||||
|
require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
|
||||||
|
require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "refresh_token missing")
|
||||||
|
require.Len(t, invalidator.accounts, 1, "cache should still be invalidated")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
account := &Account{
|
||||||
|
ID: 2882,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "expired-at",
|
||||||
|
"refresh_token": " ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
|
require.Equal(t, 0, repo.tempCalls)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user