//go:build unit package service import ( "context" "errors" "net/http" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) type rateLimitAccountRepoStub struct { mockAccountRepoForGemini setErrorCalls int tempCalls int updateCredentialsCalls int lastCredentials map[string]any lastErrorMsg string lastTempReason string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ r.lastErrorMsg = errorMsg return nil } func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ r.lastTempReason = reason return nil } func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { r.updateCredentialsCalls++ r.lastCredentials = cloneCredentials(credentials) return nil } type tokenCacheInvalidatorRecorder struct { accounts []*Account err error } type openAI403CounterCacheStub struct { counts []int64 resetCalls []int64 err error } func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) { if s.err != nil { return 0, s.err } if len(s.counts) == 0 { return 1, nil } count := s.counts[0] s.counts = s.counts[1:] return count, nil } func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { s.resetCalls = append(s.resetCalls, accountID) return nil } func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { r.accounts = append(r.accounts, account) return r.err } func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { t.Run("gemini", func(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 100, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{ "refresh_token": "rt-100", "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": 401, "keywords": []any{"unauthorized"}, "duration_minutes": 30, "description": "custom rule", }, }, }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) }) t.Run("antigravity_401_uses_SetError", func(t *testing.T) { // Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制, // HandleUpstreamError 中走 SetError 路径。 repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 100, Platform: PlatformAntigravity, Type: AccountTypeOAuth, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 0, repo.tempCalls) require.Empty(t, invalidator.accounts) }) } // TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError // OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable。 // 注意:401 handler 不再回写 credentials(避免请求开始时的快照整列覆盖 DB // 把另一个 worker 刚刷新出来的新 refresh_token 回滚为旧值), // 因此 updateCredentialsCalls 应当为 0。 func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 101, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ "refresh_token": "rt-101", }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) require.Equal(t, 0, repo.updateCredentialsCalls) require.Len(t, invalidator.accounts, 1) } func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) } // TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials // 回归测试:确保 401 handler 不再使用请求开始时的 account 快照写回 credentials。 // 原实现会通过 persistAccountCredentials → UpdateCredentials → SetCredentials // 整列覆盖 credentials JSONB,在另一个 worker 刚刷新完 refresh_token 的窄窗口内 // 会把新 refresh_token 回滚为快照中的旧值,导致下一周期拿 invalid_grant 被错误 disable。 func TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials(t *testing.T) { repo := &rateLimitAccountRepoStub{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) account := &Account{ ID: 103, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "token", "refresh_token": "rt-103", }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 0, repo.updateCredentialsCalls, "401 handler must not write credentials back from the request-start snapshot") require.Equal(t, 1, repo.tempCalls, "401 handler should still set temp-unschedulable cooldown") require.Nil(t, repo.lastCredentials, "no credentials should have been persisted") } // 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用, // 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。 func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) { t.Run("openai_no_refresh_token", func(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 2881, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "expired-at", // no refresh_token }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError") require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule") require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible") require.Contains(t, repo.lastErrorMsg, "refresh_token missing") require.Len(t, invalidator.accounts, 1, "cache should still be invalidated") }) t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) { repo := &rateLimitAccountRepoStub{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) account := &Account{ ID: 2882, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "expired-at", "refresh_token": " ", }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 0, repo.tempCalls) }) }