sub2api/backend/internal/service/ratelimit_service_401_test.go
fofoj be3613593b test(oauth): update OAuth 401 tests to match new no-write behavior
Two tests in ratelimit_service_401_test.go were encoding the bug behavior
itself:

- OAuth401InvalidatorError asserted updateCredentialsCalls == 1
- OAuth401UsesCredentialsUpdater asserted updateCredentialsCalls == 1
  and lastCredentials["expires_at"] non-empty

Both assertions exercised the exact write-back this PR removes. Update
them to reflect the new contract and guard against regression:

- OAuth401InvalidatorError: assert updateCredentialsCalls == 0
- OAuth401UsesCredentialsUpdater is renamed to
  OAuth401DoesNotOverwriteCredentials with reversed assertions, so it
  now serves as a regression test ensuring the 401 handler never writes
  credentials back from the request-start snapshot.
2026-05-28 20:32:16 +08:00

251 lines
8.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
lastTempReason string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.lastErrorMsg = errorMsg
return nil
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
r.lastTempReason = reason
return nil
}
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
r.updateCredentialsCalls++
r.lastCredentials = cloneCredentials(credentials)
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
}
type openAI403CounterCacheStub struct {
counts []int64
resetCalls []int64
err error
}
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
if s.err != nil {
return 0, s.err
}
if len(s.counts) == 0 {
return 1, nil
}
count := s.counts[0]
s.counts = s.counts[1:]
return count, nil
}
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
s.resetCalls = append(s.resetCalls, accountID)
return nil
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
t.Run("gemini", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "rt-100",
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable。
// 注意401 handler 不再回写 credentials(避免请求开始时的快照整列覆盖 DB
// 把另一个 worker 刚刷新出来的新 refresh_token 回滚为旧值),
// 因此 updateCredentialsCalls 应当为 0。
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"refresh_token": "rt-101",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
// TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials
// 回归测试:确保 401 handler 不再使用请求开始时的 account 快照写回 credentials。
// 原实现会通过 persistAccountCredentials → UpdateCredentials → SetCredentials
// 整列覆盖 credentials JSONB,在另一个 worker 刚刷新完 refresh_token 的窄窗口内
// 会把新 refresh_token 回滚为快照中的旧值,导致下一周期拿 invalid_grant 被错误 disable。
func TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
"refresh_token": "rt-103",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.updateCredentialsCalls, "401 handler must not write credentials back from the request-start snapshot")
require.Equal(t, 1, repo.tempCalls, "401 handler should still set temp-unschedulable cooldown")
require.Nil(t, repo.lastCredentials, "no credentials should have been persisted")
}
// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用,
// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。
func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
t.Run("openai_no_refresh_token", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 2881,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "expired-at",
// no refresh_token
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError")
require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
require.Contains(t, repo.lastErrorMsg, "refresh_token missing")
require.Len(t, invalidator.accounts, 1, "cache should still be invalidated")
})
t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 2882,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "expired-at",
"refresh_token": " ",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
})
}