Merge pull request #2800 from wucm667/fix/scheduler-model-not-found-per-model-cooldown

fix(scheduler): 模型 404 仅冷却该账号-模型组合,不再封整个账号
This commit is contained in:
Wesley Liddick 2026-05-27 21:01:52 +08:00 committed by GitHub
commit 61ce79533e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 434 additions and 63 deletions

View File

@ -1077,7 +1077,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil return nil
} }
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
if scope == "" { if scope == "" {
return nil return nil
} }
@ -1086,6 +1086,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
"rate_limited_at": now.Format(time.RFC3339), "rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
} }
if len(reason) > 0 {
if value := strings.TrimSpace(reason[0]); value != "" {
payload["reason"] = value
}
}
raw, err := json.Marshal(payload) raw, err := json.Marshal(payload)
if err != nil { if err != nil {
return err return err
@ -1121,6 +1126,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
} }
r.syncSchedulerAccountSnapshot(ctx, id)
return nil return nil
} }

View File

@ -1731,7 +1731,7 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return errors.New("not implemented") return errors.New("not implemented")
} }
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return errors.New("not implemented") return errors.New("not implemented")
} }

View File

@ -60,7 +60,7 @@ type AccountRepository interface {
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error ClearTempUnschedulable(ctx context.Context, id int64) error

View File

@ -159,7 +159,7 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call") panic("unexpected SetRateLimited call")
} }
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
panic("unexpected SetModelRateLimit call") panic("unexpected SetModelRateLimit call")
} }

View File

@ -1312,22 +1312,6 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
return body, nil return body, nil
} }
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换) // Forward 转发 Claude 协议请求Claude → Gemini 转换)
// //
// 限流处理流程: // 限流处理流程:

View File

@ -94,7 +94,7 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
return nil return nil
} }
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error { func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time, reason ...string) error {
s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt}) s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt})
return nil return nil
} }

View File

@ -166,7 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
Message: upstreamMsg, Message: upstreamMsg,
}) })
if s.rateLimitService != nil { if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
} }
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,

View File

@ -163,7 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
Message: upstreamMsg, Message: upstreamMsg,
}) })
if s.rateLimitService != nil { if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
} }
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,

View File

@ -156,7 +156,7 @@ func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx con
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {

View File

@ -4870,7 +4870,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
@ -4897,7 +4897,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil { if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream // ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account, reqModel)
} }
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
@ -4933,11 +4933,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} else { } else {
logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID)
} }
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account, reqModel)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
} }
} }
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account, reqModel)
} }
// 处理正常响应 // 处理正常响应
@ -5170,7 +5170,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account, input.RequestModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
@ -5195,7 +5195,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account, input.RequestModel)
} }
var usage *ClaudeUsage var usage *ClaudeUsage
@ -6959,7 +6959,7 @@ func isCountTokensUnsupported404(statusCode int, body []byte) bool {
return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found")
} }
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, requestedModel ...string) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应 // 调试日志:打印上游错误响应
@ -7006,7 +7006,11 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
shouldDisable := false shouldDisable := false
if s.rateLimitService != nil { if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) if len(requestedModel) > 0 {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
} else {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
} }
if shouldDisable { if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
@ -7122,8 +7126,12 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
} }
} }
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if len(requestedModel) > 0 {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
} }

View File

@ -147,7 +147,7 @@ func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx conte
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {

View File

@ -0,0 +1,44 @@
package service
import (
"net/http"
"strings"
)
var upstreamModelNotFoundKeywords = []string{"model not found", "unknown model", "not found"}
func isUpstreamModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != http.StatusNotFound {
return false
}
normalized := normalizeModelNotFoundBody(body)
if normalized == "" || !strings.Contains(normalized, "model") {
return false
}
return containsModelNotFoundKeyword(normalized)
}
func isModelNotFoundError(statusCode int, body []byte) bool {
return isUpstreamModelNotFoundError(statusCode, body) || statusCode == http.StatusNotFound
}
func containsModelNotFoundKeyword(normalizedBody string) bool {
if normalizedBody == "" {
return false
}
for _, keyword := range upstreamModelNotFoundKeywords {
if strings.Contains(normalizedBody, keyword) {
return true
}
}
return false
}
func normalizeModelNotFoundBody(body []byte) string {
if len(body) == 0 {
return ""
}
normalized := strings.ToLower(string(body))
normalized = strings.NewReplacer("_", " ", "-", " ", "\n", " ", "\r", " ", "\t", " ").Replace(normalized)
return strings.Join(strings.Fields(normalized), " ")
}

View File

@ -0,0 +1,66 @@
package service
import (
"net/http"
"testing"
)
func TestIsUpstreamModelNotFoundError(t *testing.T) {
tests := []struct {
name string
statusCode int
body []byte
want bool
}{
{
name: "404 model not found message",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"model not found"}}`),
want: true,
},
{
name: "404 model_not_found code",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"code":"model_not_found","message":"The requested model was not found"}}`),
want: true,
},
{
name: "404 unknown model message",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"unknown model gpt-5.4"}}`),
want: true,
},
{
name: "404 endpoint not found is not model specific",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"endpoint not found"}}`),
want: false,
},
{
name: "404 arbitrary body is not model specific",
statusCode: http.StatusNotFound,
body: []byte(`404 page not found`),
want: false,
},
{
name: "non 404 does not match",
statusCode: http.StatusBadRequest,
body: []byte(`{"error":{"message":"model not found"}}`),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isUpstreamModelNotFoundError(tt.statusCode, tt.body); got != tt.want {
t.Fatalf("isUpstreamModelNotFoundError() = %v, want %v", got, tt.want)
}
})
}
}
func TestAntigravityModelNotFoundKeepsBare404Fallback(t *testing.T) {
if !isModelNotFoundError(http.StatusNotFound, []byte(`endpoint not found`)) {
t.Fatal("antigravity model-not-found helper should keep bare 404 fallback")
}
}

View File

@ -31,7 +31,7 @@ func isOpenAIAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI return account != nil && account.Platform == PlatformOpenAI
} }
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool { func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) bool {
stateCtx, cancel := openAIAccountStateContext(ctx) stateCtx, cancel := openAIAccountStateContext(ctx)
defer cancel() defer cancel()
@ -41,6 +41,9 @@ func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Cont
if s == nil || account == nil || s.rateLimitService == nil { if s == nil || account == nil || s.rateLimitService == nil {
return false return false
} }
if len(requestedModel) > 0 && s.rateLimitService.HandleUpstreamModelNotFound(stateCtx, account, requestedModel[0], statusCode, responseBody) {
return true
}
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody) shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
if shouldDisable { if shouldDisable {
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable") s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")

View File

@ -57,6 +57,28 @@ func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T)
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account)) require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
} }
func TestOpenAIModelNotFound_DoesNotRuntimeBlockWholeAccount(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &OpenAIGatewayService{
rateLimitService: &RateLimitService{accountRepo: repo},
}
account := openAIModelNotFoundTempAccount()
shouldDisable := svc.handleOpenAIAccountUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, shouldDisable)
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
}
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) { func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
svc := &OpenAIGatewayService{} svc := &OpenAIGatewayService{}
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth} account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

View File

@ -521,6 +521,50 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID) require.Equal(t, int64(32002), account.ID)
} }
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_ModelRateLimitOnlySkipsThatModel(t *testing.T) {
ctx := context.Background()
resetAt := time.Now().Add(30 * time.Minute).Format(time.RFC3339)
primary := Account{
ID: 32101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gpt-5.4": map[string]any{
"rate_limit_reset_at": resetAt,
},
},
},
}
secondary := Account{
ID: 32102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}},
cfg: &config.Config{},
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.4", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(32102), account.ID)
account, err = svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.3", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(32101), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) { func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background() ctx := context.Background()
groupID := int64(10103) groupID := int64(10103)

View File

@ -276,14 +276,14 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
} }
} }
return s.handleChatCompletionsErrorResponse(resp, c, account) return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
} }
// 9. Handle normal response // 9. Handle normal response
@ -358,8 +358,9 @@ func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
account *Account, account *Account,
requestedModel ...string,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError, requestedModel...)
} }
// handleChatBufferedStreamingResponse reads all Responses SSE events from the // handleChatBufferedStreamingResponse reads all Responses SSE events from the

View File

@ -208,14 +208,14 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
} }
} }
return s.handleChatCompletionsErrorResponse(resp, c, account) return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
} }
// 8. Forward response // 8. Forward response

View File

@ -338,7 +338,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
@ -346,7 +346,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
} }
// Non-failover error: return Anthropic-formatted error to client // Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account) return s.handleAnthropicErrorResponse(resp, c, account, billingModel)
} }
if account.Type == AccountTypeOAuth && promptCacheKey != "" { if account.Type == AccountTypeOAuth && promptCacheKey != "" {
@ -413,8 +413,9 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
account *Account, account *Account,
requestedModel ...string,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError, requestedModel...)
} }
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from

View File

@ -188,14 +188,14 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
Message: upstreamMsg, Message: upstreamMsg,
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, chatBody) return s.handleErrorResponse(ctx, resp, c, account, chatBody, billingModel)
} }
if clientStream { if clientStream {

View File

@ -1312,8 +1312,8 @@ func openAICompactSupportTier(account *Account) int {
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / // isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
// compact-support checks used during account selection. // compact-support checks used during account selection.
func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool { func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool {
if account == nil || !account.IsSchedulable() || !account.IsOpenAI() { if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false return false
} }
if requestedModel != "" && !account.IsModelSupported(requestedModel) { if requestedModel != "" && !account.IsModelSupported(requestedModel) {
@ -1446,7 +1446,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 验证账号是否可用于当前请求 // 验证账号是否可用于当前请求
// Verify account is usable for current request // Verify account is usable for current request
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
return nil return nil
} }
if s.isOpenAIAccountRuntimeBlocked(account) { if s.isOpenAIAccountRuntimeBlocked(account) {
@ -1646,7 +1646,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if clearSticky { if clearSticky {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} }
if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) { if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil { if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
@ -1924,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
fresh = current fresh = current
} }
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) {
return nil return nil
} }
if s.isOpenAIAccountRuntimeBlocked(fresh) { if s.isOpenAIAccountRuntimeBlocked(fresh) {
@ -1938,7 +1938,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
return nil return nil
} }
if s.schedulerSnapshot == nil || s.accountRepo == nil { if s.schedulerSnapshot == nil || s.accountRepo == nil {
if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) { if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) {
return nil return nil
} }
return account return account
@ -1948,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil { if err != nil || latest == nil {
return nil return nil
} }
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) {
return nil return nil
} }
if s.isOpenAIAccountRuntimeBlocked(latest) { if s.isOpenAIAccountRuntimeBlocked(latest) {
@ -2067,8 +2067,12 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
} }
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if len(requestedModel) > 0 {
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
return
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
} }
@ -2852,14 +2856,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
Detail: upstreamDetail, Detail: upstreamDetail,
}) })
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, body) return s.handleErrorResponse(ctx, resp, c, account, body, billingModel)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
@ -3344,7 +3348,8 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
} }
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
@ -3387,7 +3392,8 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新, // 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
// 避免粘性路由继续复用刚被限流的账号。 // 避免粘性路由继续复用刚被限流的账号。
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
@ -4061,6 +4067,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
c *gin.Context, c *gin.Context,
account *Account, account *Account,
requestBody []byte, requestBody []byte,
requestedModel ...string,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@ -4137,7 +4144,14 @@ func (s *OpenAIGatewayService) handleErrorResponse(
} }
// Handle upstream error (mark account status) // Handle upstream error (mark account status)
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) var reqModel string
if len(requestedModel) > 0 {
reqModel = strings.TrimSpace(requestedModel[0])
}
if reqModel == "" {
reqModel, _, _ = extractOpenAIRequestMetaFromBody(requestBody)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
kind := "http_error" kind := "http_error"
if shouldDisable { if shouldDisable {
kind = "failover" kind = "failover"
@ -4214,6 +4228,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
c *gin.Context, c *gin.Context,
account *Account, account *Account,
writeError compatErrorWriter, writeError compatErrorWriter,
requestedModel ...string,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@ -4269,8 +4284,12 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
} }
// Track rate limits and decide whether to trigger secondary failover. // Track rate limits and decide whether to trigger secondary failover.
var modelForCooldown string
if len(requestedModel) > 0 {
modelForCooldown = requestedModel[0]
}
shouldDisable := s.handleOpenAIAccountUpstreamError( shouldDisable := s.handleOpenAIAccountUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body, c.Request.Context(), account, resp.StatusCode, resp.Header, body, modelForCooldown,
) )
kind := "http_error" kind := "http_error"
if shouldDisable { if shouldDisable {

View File

@ -638,7 +638,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
Kind: "failover", Kind: "failover",
Message: upstreamMsg, Message: upstreamMsg,
}) })
s.handleFailoverSideEffects(upstreamCtx, resp, account) s.handleFailoverSideEffects(upstreamCtx, resp, account, upstreamModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,

View File

@ -1188,7 +1188,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
Kind: "failover", Kind: "failover",
Message: upstreamMsg, Message: upstreamMsg,
}) })
s.handleFailoverSideEffects(upstreamCtx, resp, account) s.handleFailoverSideEffects(upstreamCtx, resp, account, requestModel)
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,

View File

@ -153,7 +153,7 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
// HandleUpstreamError 处理上游错误响应,标记账号状态 // HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度 // 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) (shouldDisable bool) {
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。 // 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
@ -169,6 +169,10 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false return false
} }
if len(requestedModel) > 0 && s.HandleUpstreamModelNotFound(ctx, account, requestedModel[0], statusCode, responseBody) {
return true
}
// 先尝试临时不可调度规则401除外 // 先尝试临时不可调度规则401除外
// 如果匹配成功,直接返回,不执行后续禁用逻辑 // 如果匹配成功,直接返回,不执行后续禁用逻辑
if statusCode != 401 { if statusCode != 401 {
@ -1616,9 +1620,51 @@ func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody) return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
} }
const upstreamModelNotFoundCooldown = 30 * time.Minute
const upstreamModelNotFoundReason = "upstream_404_model_not_found"
const tempUnschedBodyMaxBytes = 64 << 10 const tempUnschedBodyMaxBytes = 64 << 10
const tempUnschedMessageMaxBytes = 2048 const tempUnschedMessageMaxBytes = 2048
func (s *RateLimitService) HandleUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string, statusCode int, responseBody []byte) bool {
if s == nil || account == nil || s.accountRepo == nil {
return false
}
if !account.ShouldHandleErrorCode(statusCode) {
return false
}
if !isUpstreamModelNotFoundError(statusCode, responseBody) {
return false
}
modelKey := modelRateLimitKeyForUpstreamModelNotFound(ctx, account, requestedModel)
if modelKey == "" {
return false
}
resetAt := time.Now().Add(upstreamModelNotFoundCooldown)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, resetAt, upstreamModelNotFoundReason); err != nil {
slog.Warn("upstream_model_not_found_set_model_rate_limit_failed", "account_id", account.ID, "model", modelKey, "error", err)
return true
}
slog.Info("upstream_model_not_found_model_rate_limited", "account_id", account.ID, "model", modelKey, "reset_at", resetAt)
return true
}
func modelRateLimitKeyForUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string) string {
modelKey := strings.TrimSpace(requestedModel)
if account == nil || modelKey == "" {
return modelKey
}
if account.Platform == PlatformAntigravity {
if resolved := strings.TrimSpace(resolveFinalAntigravityModelKey(ctx, account, modelKey)); resolved != "" {
return resolved
}
return modelKey
}
if mapped := strings.TrimSpace(account.GetMappedModel(modelKey)); mapped != "" {
return mapped
}
return modelKey
}
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool { func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil { if account == nil {
return false return false

View File

@ -0,0 +1,127 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type modelNotFoundRateLimitCall struct {
accountID int64
scope string
resetAt time.Time
reason string
}
type modelNotFoundAccountRepoStub struct {
mockAccountRepoForGemini
tempCalls int
modelRateLimitCalls []modelNotFoundRateLimitCall
modelRateLimitErr error
}
func (r *modelNotFoundAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *modelNotFoundAccountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
call := modelNotFoundRateLimitCall{
accountID: id,
scope: scope,
resetAt: resetAt,
}
if len(reason) > 0 {
call.reason = reason[0]
}
r.modelRateLimitCalls = append(r.modelRateLimitCalls, call)
return r.modelRateLimitErr
}
func TestRateLimitService_HandleUpstreamError_ModelNotFoundUsesModelRateLimit(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
call := repo.modelRateLimitCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, "gpt-5.4", call.scope)
require.Equal(t, upstreamModelNotFoundReason, call.reason)
require.WithinDuration(t, time.Now().Add(upstreamModelNotFoundCooldown), call.resetAt, 5*time.Second)
}
func TestRateLimitService_HandleUpstreamError_ModelNotFoundWriteFailureDoesNotTempUnschedule(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{modelRateLimitErr: errors.New("write failed")}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
}
func TestRateLimitService_HandleUpstreamError_Bare404KeepsTempUnschedulablePath(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"message":"endpoint not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Equal(t, 1, repo.tempCalls)
require.Empty(t, repo.modelRateLimitCalls)
}
func openAIModelNotFoundTempAccount() *Account {
return &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(http.StatusNotFound),
"keywords": []any{"not found"},
"duration_minutes": float64(10),
},
},
},
}
}

View File

@ -137,7 +137,7 @@ func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatforms(context.Cont
func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error { func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time) error { func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error {
panic("unexpected") panic("unexpected")
} }
func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error { func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error {