token_provider 在 expires_at 已过期且缺少 refresh_token 时只返回 error,没有任何降级处理;HandleUpstreamError 的 OAuth 401 分支也只做 10 分钟冷却,不区分账号是否具备刷新能力。两条路径叠加,导致缺少 refresh_token 的账号被反复选中、每次都在获取 token 阶段失败,用户侧持续看到 502。修复:token_provider.GetAccessToken 命中“已过期且无 refresh_token”时调用 SetError 永久禁用账号并清除缓存,落库使用 background context,避免请求 ctx 提前结束导致写入失败;ratelimit_service 的 OAuth 401 分支在 refresh_token 为空时直接 SetError,不再写入 expires_at,也不再调用 SetTempUnschedulable,但保留缓存失效。带 refresh_token 的账号路径完全不变。新增/调整测试覆盖这两条路径;旧测试为 RT 路径补上 refresh_token 字段,以保留原有测试意图。
242 lines
7.6 KiB
Go
242 lines
7.6 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type rateLimitAccountRepoStub struct {
|
|
mockAccountRepoForGemini
|
|
setErrorCalls int
|
|
tempCalls int
|
|
updateCredentialsCalls int
|
|
lastCredentials map[string]any
|
|
lastErrorMsg string
|
|
lastTempReason string
|
|
}
|
|
|
|
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
|
r.setErrorCalls++
|
|
r.lastErrorMsg = errorMsg
|
|
return nil
|
|
}
|
|
|
|
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
r.tempCalls++
|
|
r.lastTempReason = reason
|
|
return nil
|
|
}
|
|
|
|
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
|
r.updateCredentialsCalls++
|
|
r.lastCredentials = cloneCredentials(credentials)
|
|
return nil
|
|
}
|
|
|
|
type tokenCacheInvalidatorRecorder struct {
|
|
accounts []*Account
|
|
err error
|
|
}
|
|
|
|
// openAI403CounterCacheStub is a test double for the OpenAI 403 counter cache.
type openAI403CounterCacheStub struct {
	// counts holds scripted increment results, consumed front to back;
	// resetCalls records the account IDs passed to each reset.
	counts, resetCalls []int64
	// err, when non-nil, makes every increment fail with this error.
	err error
}
|
|
|
|
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
|
|
if s.err != nil {
|
|
return 0, s.err
|
|
}
|
|
if len(s.counts) == 0 {
|
|
return 1, nil
|
|
}
|
|
count := s.counts[0]
|
|
s.counts = s.counts[1:]
|
|
return count, nil
|
|
}
|
|
|
|
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
|
|
s.resetCalls = append(s.resetCalls, accountID)
|
|
return nil
|
|
}
|
|
|
|
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
|
r.accounts = append(r.accounts, account)
|
|
return r.err
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
|
t.Run("gemini", func(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 100,
|
|
Platform: PlatformGemini,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"refresh_token": "rt-100",
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": 401,
|
|
"keywords": []any{"unauthorized"},
|
|
"duration_minutes": 30,
|
|
"description": "custom rule",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 0, repo.setErrorCalls)
|
|
require.Equal(t, 1, repo.tempCalls)
|
|
require.Len(t, invalidator.accounts, 1)
|
|
})
|
|
|
|
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
|
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
|
// HandleUpstreamError 中走 SetError 路径。
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 100,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Equal(t, 0, repo.tempCalls)
|
|
require.Empty(t, invalidator.accounts)
|
|
})
|
|
}
|
|
|
|
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
|
|
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 101,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"refresh_token": "rt-101",
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 0, repo.setErrorCalls)
|
|
require.Equal(t, 1, repo.tempCalls)
|
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
|
require.Len(t, invalidator.accounts, 1)
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 102,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeAPIKey,
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Empty(t, invalidator.accounts)
|
|
}
|
|
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
account := &Account{
|
|
ID: 103,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"access_token": "token",
|
|
"refresh_token": "rt-103",
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.updateCredentialsCalls)
|
|
require.NotEmpty(t, repo.lastCredentials["expires_at"])
|
|
}
|
|
|
|
// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用,
|
|
// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。
|
|
func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
|
|
t.Run("openai_no_refresh_token", func(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
service.SetTokenCacheInvalidator(invalidator)
|
|
account := &Account{
|
|
ID: 2881,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"access_token": "expired-at",
|
|
// no refresh_token
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError")
|
|
require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
|
|
require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
|
|
require.Contains(t, repo.lastErrorMsg, "refresh_token missing")
|
|
require.Len(t, invalidator.accounts, 1, "cache should still be invalidated")
|
|
})
|
|
|
|
t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
|
|
repo := &rateLimitAccountRepoStub{}
|
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
account := &Account{
|
|
ID: 2882,
|
|
Platform: PlatformOpenAI,
|
|
Type: AccountTypeOAuth,
|
|
Credentials: map[string]any{
|
|
"access_token": "expired-at",
|
|
"refresh_token": " ",
|
|
},
|
|
}
|
|
|
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
|
|
require.True(t, shouldDisable)
|
|
require.Equal(t, 1, repo.setErrorCalls)
|
|
require.Equal(t, 0, repo.tempCalls)
|
|
})
|
|
}
|