Two tests in ratelimit_service_401_test.go encoded the buggy behavior itself:

- OAuth401InvalidatorError asserted updateCredentialsCalls == 1.
- OAuth401UsesCredentialsUpdater asserted updateCredentialsCalls == 1 and a non-empty lastCredentials["expires_at"].

Both assertions exercised the exact write-back that this PR removes. Update them to reflect the new contract and to guard against regression:

- OAuth401InvalidatorError now asserts updateCredentialsCalls == 0.
- OAuth401UsesCredentialsUpdater is renamed to OAuth401DoesNotOverwriteCredentials and its assertions are reversed, so it now serves as a regression test ensuring that the 401 handler never writes credentials back from the request-start snapshot.
251 lines
8.5 KiB
Go
251 lines
8.5 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"net/http"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
type rateLimitAccountRepoStub struct {
|
||
mockAccountRepoForGemini
|
||
setErrorCalls int
|
||
tempCalls int
|
||
updateCredentialsCalls int
|
||
lastCredentials map[string]any
|
||
lastErrorMsg string
|
||
lastTempReason string
|
||
}
|
||
|
||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||
r.setErrorCalls++
|
||
r.lastErrorMsg = errorMsg
|
||
return nil
|
||
}
|
||
|
||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||
r.tempCalls++
|
||
r.lastTempReason = reason
|
||
return nil
|
||
}
|
||
|
||
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||
r.updateCredentialsCalls++
|
||
r.lastCredentials = cloneCredentials(credentials)
|
||
return nil
|
||
}
|
||
|
||
type tokenCacheInvalidatorRecorder struct {
|
||
accounts []*Account
|
||
err error
|
||
}
|
||
|
||
// openAI403CounterCacheStub is a scripted test double for the OpenAI 403
// counter cache. Increment results are served from counts, and every reset is
// recorded in resetCalls.
type openAI403CounterCacheStub struct {
	counts     []int64 // queued results for IncrementOpenAI403Count, consumed front to back
	resetCalls []int64 // account IDs passed to ResetOpenAI403Count, in call order
	err        error   // when non-nil, IncrementOpenAI403Count fails with this error
}

// IncrementOpenAI403Count returns (0, s.err) when an error is configured.
// Otherwise it pops and returns the next queued count, or 1 once the queue is
// empty. All arguments are ignored.
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	if s.err != nil {
		return 0, s.err
	}
	if len(s.counts) == 0 {
		return 1, nil
	}
	count := s.counts[0]
	s.counts = s.counts[1:]
	return count, nil
}

// ResetOpenAI403Count records accountID and always succeeds; the configured
// err only affects IncrementOpenAI403Count.
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
|
||
|
||
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||
r.accounts = append(r.accounts, account)
|
||
return r.err
|
||
}
|
||
|
||
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
||
t.Run("gemini", func(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
service.SetTokenCacheInvalidator(invalidator)
|
||
account := &Account{
|
||
ID: 100,
|
||
Platform: PlatformGemini,
|
||
Type: AccountTypeOAuth,
|
||
Credentials: map[string]any{
|
||
"refresh_token": "rt-100",
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": 401,
|
||
"keywords": []any{"unauthorized"},
|
||
"duration_minutes": 30,
|
||
"description": "custom rule",
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 0, repo.setErrorCalls)
|
||
require.Equal(t, 1, repo.tempCalls)
|
||
require.Len(t, invalidator.accounts, 1)
|
||
})
|
||
|
||
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
||
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
||
// HandleUpstreamError 中走 SetError 路径。
|
||
repo := &rateLimitAccountRepoStub{}
|
||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
service.SetTokenCacheInvalidator(invalidator)
|
||
account := &Account{
|
||
ID: 100,
|
||
Platform: PlatformAntigravity,
|
||
Type: AccountTypeOAuth,
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 1, repo.setErrorCalls)
|
||
require.Equal(t, 0, repo.tempCalls)
|
||
require.Empty(t, invalidator.accounts)
|
||
})
|
||
}
|
||
|
||
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
|
||
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable。
|
||
// 注意:401 handler 不再回写 credentials(避免请求开始时的快照整列覆盖 DB
|
||
// 把另一个 worker 刚刷新出来的新 refresh_token 回滚为旧值),
|
||
// 因此 updateCredentialsCalls 应当为 0。
|
||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
service.SetTokenCacheInvalidator(invalidator)
|
||
account := &Account{
|
||
ID: 101,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Credentials: map[string]any{
|
||
"refresh_token": "rt-101",
|
||
},
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 0, repo.setErrorCalls)
|
||
require.Equal(t, 1, repo.tempCalls)
|
||
require.Equal(t, 0, repo.updateCredentialsCalls)
|
||
require.Len(t, invalidator.accounts, 1)
|
||
}
|
||
|
||
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
service.SetTokenCacheInvalidator(invalidator)
|
||
account := &Account{
|
||
ID: 102,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeAPIKey,
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 1, repo.setErrorCalls)
|
||
require.Empty(t, invalidator.accounts)
|
||
}
|
||
|
||
// TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials
|
||
// 回归测试:确保 401 handler 不再使用请求开始时的 account 快照写回 credentials。
|
||
// 原实现会通过 persistAccountCredentials → UpdateCredentials → SetCredentials
|
||
// 整列覆盖 credentials JSONB,在另一个 worker 刚刷新完 refresh_token 的窄窗口内
|
||
// 会把新 refresh_token 回滚为快照中的旧值,导致下一周期拿 invalid_grant 被错误 disable。
|
||
func TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
account := &Account{
|
||
ID: 103,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Credentials: map[string]any{
|
||
"access_token": "token",
|
||
"refresh_token": "rt-103",
|
||
},
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 0, repo.updateCredentialsCalls, "401 handler must not write credentials back from the request-start snapshot")
|
||
require.Equal(t, 1, repo.tempCalls, "401 handler should still set temp-unschedulable cooldown")
|
||
require.Nil(t, repo.lastCredentials, "no credentials should have been persisted")
|
||
}
|
||
|
||
// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用,
|
||
// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。
|
||
func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
|
||
t.Run("openai_no_refresh_token", func(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
service.SetTokenCacheInvalidator(invalidator)
|
||
account := &Account{
|
||
ID: 2881,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Credentials: map[string]any{
|
||
"access_token": "expired-at",
|
||
// no refresh_token
|
||
},
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError")
|
||
require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
|
||
require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
|
||
require.Contains(t, repo.lastErrorMsg, "refresh_token missing")
|
||
require.Len(t, invalidator.accounts, 1, "cache should still be invalidated")
|
||
})
|
||
|
||
t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
|
||
repo := &rateLimitAccountRepoStub{}
|
||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
account := &Account{
|
||
ID: 2882,
|
||
Platform: PlatformOpenAI,
|
||
Type: AccountTypeOAuth,
|
||
Credentials: map[string]any{
|
||
"access_token": "expired-at",
|
||
"refresh_token": " ",
|
||
},
|
||
}
|
||
|
||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 1, repo.setErrorCalls)
|
||
require.Equal(t, 0, repo.tempCalls)
|
||
})
|
||
}
|