fix(scheduler): 模型404仅冷却账号模型组合

This commit is contained in:
wucm667 2026-05-26 20:29:48 +08:00
parent 9ef144874a
commit a31b507484
26 changed files with 434 additions and 63 deletions

View File

@ -1077,7 +1077,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
if scope == "" {
return nil
}
@ -1086,6 +1086,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
if len(reason) > 0 {
if value := strings.TrimSpace(reason[0]); value != "" {
payload["reason"] = value
}
}
raw, err := json.Marshal(payload)
if err != nil {
return err
@ -1121,6 +1126,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
}

View File

@ -1731,7 +1731,7 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return errors.New("not implemented")
}

View File

@ -60,7 +60,7 @@ type AccountRepository interface {
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error

View File

@ -159,7 +159,7 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
panic("unexpected SetModelRateLimit call")
}

View File

@ -1312,22 +1312,6 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
return body, nil
}
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
//
// 限流处理流程:

View File

@ -94,7 +94,7 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
return nil
}
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error {
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time, reason ...string) error {
s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt})
return nil
}

View File

@ -166,7 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,

View File

@ -163,7 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
Message: upstreamMsg,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,

View File

@ -156,7 +156,7 @@ func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx con
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {

View File

@ -4870,7 +4870,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(ctx, resp, account, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -4897,7 +4897,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@ -4933,11 +4933,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} else {
logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(ctx, resp, account, reqModel)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
}
// 处理正常响应
@ -5170,7 +5170,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(ctx, resp, account, input.RequestModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -5195,7 +5195,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
}
if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account)
return s.handleErrorResponse(ctx, resp, c, account, input.RequestModel)
}
var usage *ClaudeUsage
@ -6959,7 +6959,7 @@ func isCountTokensUnsupported404(statusCode int, body []byte) bool {
return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found")
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, requestedModel ...string) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
@ -7006,7 +7006,11 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// 处理上游错误,标记账号状态
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
if len(requestedModel) > 0 {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
} else {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
@ -7122,8 +7126,12 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
}
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if len(requestedModel) > 0 {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}

View File

@ -147,7 +147,7 @@ func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx conte
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {

View File

@ -0,0 +1,44 @@
package service
import (
"net/http"
"strings"
)
var upstreamModelNotFoundKeywords = []string{"model not found", "unknown model", "not found"}
func isUpstreamModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != http.StatusNotFound {
return false
}
normalized := normalizeModelNotFoundBody(body)
if normalized == "" || !strings.Contains(normalized, "model") {
return false
}
return containsModelNotFoundKeyword(normalized)
}
func isModelNotFoundError(statusCode int, body []byte) bool {
return isUpstreamModelNotFoundError(statusCode, body) || statusCode == http.StatusNotFound
}
func containsModelNotFoundKeyword(normalizedBody string) bool {
if normalizedBody == "" {
return false
}
for _, keyword := range upstreamModelNotFoundKeywords {
if strings.Contains(normalizedBody, keyword) {
return true
}
}
return false
}
func normalizeModelNotFoundBody(body []byte) string {
if len(body) == 0 {
return ""
}
normalized := strings.ToLower(string(body))
normalized = strings.NewReplacer("_", " ", "-", " ", "\n", " ", "\r", " ", "\t", " ").Replace(normalized)
return strings.Join(strings.Fields(normalized), " ")
}

View File

@ -0,0 +1,66 @@
package service
import (
"net/http"
"testing"
)
func TestIsUpstreamModelNotFoundError(t *testing.T) {
tests := []struct {
name string
statusCode int
body []byte
want bool
}{
{
name: "404 model not found message",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"model not found"}}`),
want: true,
},
{
name: "404 model_not_found code",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"code":"model_not_found","message":"The requested model was not found"}}`),
want: true,
},
{
name: "404 unknown model message",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"unknown model gpt-5.4"}}`),
want: true,
},
{
name: "404 endpoint not found is not model specific",
statusCode: http.StatusNotFound,
body: []byte(`{"error":{"message":"endpoint not found"}}`),
want: false,
},
{
name: "404 arbitrary body is not model specific",
statusCode: http.StatusNotFound,
body: []byte(`404 page not found`),
want: false,
},
{
name: "non 404 does not match",
statusCode: http.StatusBadRequest,
body: []byte(`{"error":{"message":"model not found"}}`),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isUpstreamModelNotFoundError(tt.statusCode, tt.body); got != tt.want {
t.Fatalf("isUpstreamModelNotFoundError() = %v, want %v", got, tt.want)
}
})
}
}
func TestAntigravityModelNotFoundKeepsBare404Fallback(t *testing.T) {
if !isModelNotFoundError(http.StatusNotFound, []byte(`endpoint not found`)) {
t.Fatal("antigravity model-not-found helper should keep bare 404 fallback")
}
}

View File

@ -31,7 +31,7 @@ func isOpenAIAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI
}
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) bool {
stateCtx, cancel := openAIAccountStateContext(ctx)
defer cancel()
@ -41,6 +41,9 @@ func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Cont
if s == nil || account == nil || s.rateLimitService == nil {
return false
}
if len(requestedModel) > 0 && s.rateLimitService.HandleUpstreamModelNotFound(stateCtx, account, requestedModel[0], statusCode, responseBody) {
return true
}
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
if shouldDisable {
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")

View File

@ -57,6 +57,28 @@ func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T)
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIModelNotFound_DoesNotRuntimeBlockWholeAccount(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &OpenAIGatewayService{
rateLimitService: &RateLimitService{accountRepo: repo},
}
account := openAIModelNotFoundTempAccount()
shouldDisable := svc.handleOpenAIAccountUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, shouldDisable)
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
}
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

View File

@ -521,6 +521,50 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
require.Equal(t, int64(32002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_ModelRateLimitOnlySkipsThatModel(t *testing.T) {
ctx := context.Background()
resetAt := time.Now().Add(30 * time.Minute).Format(time.RFC3339)
primary := Account{
ID: 32101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"gpt-5.4": map[string]any{
"rate_limit_reset_at": resetAt,
},
},
},
}
secondary := Account{
ID: 32102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}},
cfg: &config.Config{},
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.4", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(32102), account.ID)
account, err = svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.3", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(32101), account.ID)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
ctx := context.Background()
groupID := int64(10103)

View File

@ -276,14 +276,14 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
}
// 9. Handle normal response
@ -358,8 +358,9 @@ func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
requestedModel ...string,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError, requestedModel...)
}
// handleChatBufferedStreamingResponse reads all Responses SSE events from the

View File

@ -207,14 +207,14 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
}
// 8. Forward response

View File

@ -337,7 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
@ -345,7 +345,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
// Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account)
return s.handleAnthropicErrorResponse(resp, c, account, billingModel)
}
if account.Type == AccountTypeOAuth && promptCacheKey != "" {
@ -412,8 +412,9 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
requestedModel ...string,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError, requestedModel...)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from

View File

@ -188,14 +188,14 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, chatBody)
return s.handleErrorResponse(ctx, resp, c, account, chatBody, billingModel)
}
if clientStream {

View File

@ -1312,8 +1312,8 @@ func openAICompactSupportTier(account *Account) int {
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
// compact-support checks used during account selection.
func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool {
if account == nil || !account.IsSchedulable() || !account.IsOpenAI() {
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool {
if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
@ -1446,7 +1446,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(account) {
@ -1646,7 +1646,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if clearSticky {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
}
if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
@ -1924,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
fresh = current
}
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(fresh) {
@ -1938,7 +1938,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) {
return nil
}
return account
@ -1948,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil {
return nil
}
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(latest) {
@ -2067,8 +2067,12 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if len(requestedModel) > 0 {
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
return
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
@ -2850,14 +2854,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
s.handleFailoverSideEffects(ctx, resp, account, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
return s.handleErrorResponse(ctx, resp, c, account, body, billingModel)
}
defer func() { _ = resp.Body.Close() }()
@ -3352,7 +3356,8 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -3395,7 +3400,8 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
// 避免粘性路由继续复用刚被限流的账号。
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -4069,6 +4075,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
c *gin.Context,
account *Account,
requestBody []byte,
requestedModel ...string,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@ -4145,7 +4152,14 @@ func (s *OpenAIGatewayService) handleErrorResponse(
}
// Handle upstream error (mark account status)
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
var reqModel string
if len(requestedModel) > 0 {
reqModel = strings.TrimSpace(requestedModel[0])
}
if reqModel == "" {
reqModel, _, _ = extractOpenAIRequestMetaFromBody(requestBody)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
kind := "http_error"
if shouldDisable {
kind = "failover"
@ -4222,6 +4236,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
c *gin.Context,
account *Account,
writeError compatErrorWriter,
requestedModel ...string,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@ -4277,8 +4292,12 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
}
// Track rate limits and decide whether to trigger secondary failover.
var modelForCooldown string
if len(requestedModel) > 0 {
modelForCooldown = requestedModel[0]
}
shouldDisable := s.handleOpenAIAccountUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
c.Request.Context(), account, resp.StatusCode, resp.Header, body, modelForCooldown,
)
kind := "http_error"
if shouldDisable {

View File

@ -638,7 +638,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(upstreamCtx, resp, account)
s.handleFailoverSideEffects(upstreamCtx, resp, account, upstreamModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -1188,7 +1188,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
Kind: "failover",
Message: upstreamMsg,
})
s.handleFailoverSideEffects(upstreamCtx, resp, account)
s.handleFailoverSideEffects(upstreamCtx, resp, account, requestModel)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -153,7 +153,7 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) (shouldDisable bool) {
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
@ -169,6 +169,10 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false
}
if len(requestedModel) > 0 && s.HandleUpstreamModelNotFound(ctx, account, requestedModel[0], statusCode, responseBody) {
return true
}
// 先尝试临时不可调度规则401除外
// 如果匹配成功,直接返回,不执行后续禁用逻辑
if statusCode != 401 {
@ -1616,9 +1620,51 @@ func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
const upstreamModelNotFoundCooldown = 30 * time.Minute
const upstreamModelNotFoundReason = "upstream_404_model_not_found"
const tempUnschedBodyMaxBytes = 64 << 10
const tempUnschedMessageMaxBytes = 2048
func (s *RateLimitService) HandleUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string, statusCode int, responseBody []byte) bool {
if s == nil || account == nil || s.accountRepo == nil {
return false
}
if !account.ShouldHandleErrorCode(statusCode) {
return false
}
if !isUpstreamModelNotFoundError(statusCode, responseBody) {
return false
}
modelKey := modelRateLimitKeyForUpstreamModelNotFound(ctx, account, requestedModel)
if modelKey == "" {
return false
}
resetAt := time.Now().Add(upstreamModelNotFoundCooldown)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, resetAt, upstreamModelNotFoundReason); err != nil {
slog.Warn("upstream_model_not_found_set_model_rate_limit_failed", "account_id", account.ID, "model", modelKey, "error", err)
return true
}
slog.Info("upstream_model_not_found_model_rate_limited", "account_id", account.ID, "model", modelKey, "reset_at", resetAt)
return true
}
func modelRateLimitKeyForUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string) string {
modelKey := strings.TrimSpace(requestedModel)
if account == nil || modelKey == "" {
return modelKey
}
if account.Platform == PlatformAntigravity {
if resolved := strings.TrimSpace(resolveFinalAntigravityModelKey(ctx, account, modelKey)); resolved != "" {
return resolved
}
return modelKey
}
if mapped := strings.TrimSpace(account.GetMappedModel(modelKey)); mapped != "" {
return mapped
}
return modelKey
}
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil {
return false

View File

@ -0,0 +1,127 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type modelNotFoundRateLimitCall struct {
accountID int64
scope string
resetAt time.Time
reason string
}
type modelNotFoundAccountRepoStub struct {
mockAccountRepoForGemini
tempCalls int
modelRateLimitCalls []modelNotFoundRateLimitCall
modelRateLimitErr error
}
func (r *modelNotFoundAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
func (r *modelNotFoundAccountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
call := modelNotFoundRateLimitCall{
accountID: id,
scope: scope,
resetAt: resetAt,
}
if len(reason) > 0 {
call.reason = reason[0]
}
r.modelRateLimitCalls = append(r.modelRateLimitCalls, call)
return r.modelRateLimitErr
}
func TestRateLimitService_HandleUpstreamError_ModelNotFoundUsesModelRateLimit(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
call := repo.modelRateLimitCalls[0]
require.Equal(t, account.ID, call.accountID)
require.Equal(t, "gpt-5.4", call.scope)
require.Equal(t, upstreamModelNotFoundReason, call.reason)
require.WithinDuration(t, time.Now().Add(upstreamModelNotFoundCooldown), call.resetAt, 5*time.Second)
}
func TestRateLimitService_HandleUpstreamError_ModelNotFoundWriteFailureDoesNotTempUnschedule(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{modelRateLimitErr: errors.New("write failed")}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Zero(t, repo.tempCalls)
require.Len(t, repo.modelRateLimitCalls, 1)
}
func TestRateLimitService_HandleUpstreamError_Bare404KeepsTempUnschedulablePath(t *testing.T) {
repo := &modelNotFoundAccountRepoStub{}
svc := &RateLimitService{accountRepo: repo}
account := openAIModelNotFoundTempAccount()
handled := svc.HandleUpstreamError(
context.Background(),
account,
http.StatusNotFound,
http.Header{},
[]byte(`{"error":{"message":"endpoint not found"}}`),
"gpt-5.4",
)
require.True(t, handled)
require.Equal(t, 1, repo.tempCalls)
require.Empty(t, repo.modelRateLimitCalls)
}
func openAIModelNotFoundTempAccount() *Account {
return &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(http.StatusNotFound),
"keywords": []any{"not found"},
"duration_minutes": float64(10),
},
},
},
}
}

View File

@ -137,7 +137,7 @@ func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatforms(context.Cont
func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error {
panic("unexpected")
}
func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time) error {
func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error {
panic("unexpected")
}
func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error {