fix(scheduler): 模型404仅冷却账号模型组合
This commit is contained in:
parent
9ef144874a
commit
a31b507484
@ -1077,7 +1077,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
}
|
||||
@ -1086,6 +1086,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
if len(reason) > 0 {
|
||||
if value := strings.TrimSpace(reason[0]); value != "" {
|
||||
payload["reason"] = value
|
||||
}
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1121,6 +1126,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -1731,7 +1731,7 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ type AccountRepository interface {
|
||||
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
|
||||
@ -159,7 +159,7 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
|
||||
@ -1312,22 +1312,6 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// isModelNotFoundError 检测是否为模型不存在的 404 错误
|
||||
func isModelNotFoundError(statusCode int, body []byte) bool {
|
||||
if statusCode != 404 {
|
||||
return false
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
keywords := []string{"model not found", "unknown model", "not found"}
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(bodyStr, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true // 404 without specific message also treated as model not found
|
||||
}
|
||||
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
//
|
||||
// 限流处理流程:
|
||||
|
||||
@ -94,7 +94,7 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error {
|
||||
func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time, reason ...string) error {
|
||||
s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
|
||||
@ -163,7 +163,7 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel)
|
||||
}
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
|
||||
@ -156,7 +156,7 @@ func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx con
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
|
||||
@ -4870,7 +4870,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(ctx, resp, account, reqModel)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -4897,7 +4897,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
@ -4933,11 +4933,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(ctx, resp, account, reqModel)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
|
||||
}
|
||||
|
||||
// 处理正常响应
|
||||
@ -5170,7 +5170,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(ctx, resp, account, input.RequestModel)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -5195,7 +5195,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
return s.handleErrorResponse(ctx, resp, c, account, input.RequestModel)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
@ -6959,7 +6959,7 @@ func isCountTokensUnsupported404(statusCode int, body []byte) bool {
|
||||
return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found")
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, requestedModel ...string) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
@ -7006,7 +7006,11 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
// 处理上游错误,标记账号状态
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
if len(requestedModel) > 0 {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
|
||||
} else {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
@ -7122,8 +7126,12 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if len(requestedModel) > 0 {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
|
||||
@ -147,7 +147,7 @@ func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx conte
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
|
||||
44
backend/internal/service/model_not_found_error.go
Normal file
44
backend/internal/service/model_not_found_error.go
Normal file
@ -0,0 +1,44 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// upstreamModelNotFoundKeywords lists the lower-case phrases that
// containsModelNotFoundKeyword searches for in a normalized upstream error
// body. Any text containing "model not found" also contains "not found", so
// the first entry is subsumed by the last one.
var upstreamModelNotFoundKeywords = []string{"model not found", "unknown model", "not found"}
|
||||
|
||||
func isUpstreamModelNotFoundError(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusNotFound {
|
||||
return false
|
||||
}
|
||||
normalized := normalizeModelNotFoundBody(body)
|
||||
if normalized == "" || !strings.Contains(normalized, "model") {
|
||||
return false
|
||||
}
|
||||
return containsModelNotFoundKeyword(normalized)
|
||||
}
|
||||
|
||||
// isModelNotFoundError is the permissive model-not-found check: every 404 is
// treated as a missing model, whatever the body says.
//
// isUpstreamModelNotFoundError can only return true for a 404, so OR-ing it
// with the status comparison never changed the outcome; it only forced a full
// normalization pass over the body. The check is therefore reduced to the
// status comparison. body is kept in the signature for existing callers.
func isModelNotFoundError(statusCode int, body []byte) bool {
	return statusCode == http.StatusNotFound
}
|
||||
|
||||
func containsModelNotFoundKeyword(normalizedBody string) bool {
|
||||
if normalizedBody == "" {
|
||||
return false
|
||||
}
|
||||
for _, keyword := range upstreamModelNotFoundKeywords {
|
||||
if strings.Contains(normalizedBody, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// modelNotFoundSeparatorReplacer turns the word separators used in error
// codes ("model_not_found", "model-not-found") into spaces. It is built once
// at package level instead of on every call; a strings.Replacer is safe for
// concurrent use.
var modelNotFoundSeparatorReplacer = strings.NewReplacer("_", " ", "-", " ")

// normalizeModelNotFoundBody converts a raw upstream error body into a
// canonical form for keyword matching: lower-cased, with "_" and "-" replaced
// by spaces, and every run of whitespace collapsed to a single space with no
// leading or trailing space. An empty body yields "".
//
// Newlines, carriage returns and tabs need no explicit replacement because
// strings.Fields already splits on them.
func normalizeModelNotFoundBody(body []byte) string {
	if len(body) == 0 {
		return ""
	}
	lowered := strings.ToLower(string(body))
	separated := modelNotFoundSeparatorReplacer.Replace(lowered)
	return strings.Join(strings.Fields(separated), " ")
}
|
||||
66
backend/internal/service/model_not_found_error_test.go
Normal file
66
backend/internal/service/model_not_found_error_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsUpstreamModelNotFoundError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "404 model not found message",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: []byte(`{"error":{"message":"model not found"}}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "404 model_not_found code",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: []byte(`{"error":{"code":"model_not_found","message":"The requested model was not found"}}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "404 unknown model message",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: []byte(`{"error":{"message":"unknown model gpt-5.4"}}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "404 endpoint not found is not model specific",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: []byte(`{"error":{"message":"endpoint not found"}}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "404 arbitrary body is not model specific",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: []byte(`404 page not found`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non 404 does not match",
|
||||
statusCode: http.StatusBadRequest,
|
||||
body: []byte(`{"error":{"message":"model not found"}}`),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isUpstreamModelNotFoundError(tt.statusCode, tt.body); got != tt.want {
|
||||
t.Fatalf("isUpstreamModelNotFoundError() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityModelNotFoundKeepsBare404Fallback(t *testing.T) {
|
||||
if !isModelNotFoundError(http.StatusNotFound, []byte(`endpoint not found`)) {
|
||||
t.Fatal("antigravity model-not-found helper should keep bare 404 fallback")
|
||||
}
|
||||
}
|
||||
@ -31,7 +31,7 @@ func isOpenAIAccount(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
|
||||
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) bool {
|
||||
stateCtx, cancel := openAIAccountStateContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
@ -41,6 +41,9 @@ func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Cont
|
||||
if s == nil || account == nil || s.rateLimitService == nil {
|
||||
return false
|
||||
}
|
||||
if len(requestedModel) > 0 && s.rateLimitService.HandleUpstreamModelNotFound(stateCtx, account, requestedModel[0], statusCode, responseBody) {
|
||||
return true
|
||||
}
|
||||
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
|
||||
if shouldDisable {
|
||||
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||
|
||||
@ -57,6 +57,28 @@ func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T)
|
||||
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIModelNotFound_DoesNotRuntimeBlockWholeAccount(t *testing.T) {
|
||||
repo := &modelNotFoundAccountRepoStub{}
|
||||
svc := &OpenAIGatewayService{
|
||||
rateLimitService: &RateLimitService{accountRepo: repo},
|
||||
}
|
||||
account := openAIModelNotFoundTempAccount()
|
||||
|
||||
shouldDisable := svc.handleOpenAIAccountUpstreamError(
|
||||
context.Background(),
|
||||
account,
|
||||
http.StatusNotFound,
|
||||
http.Header{},
|
||||
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
|
||||
"gpt-5.4",
|
||||
)
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
require.Zero(t, repo.tempCalls)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
@ -521,6 +521,50 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
|
||||
require.Equal(t, int64(32002), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_ModelRateLimitOnlySkipsThatModel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resetAt := time.Now().Add(30 * time.Minute).Format(time.RFC3339)
|
||||
primary := Account{
|
||||
ID: 32101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"gpt-5.4": map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
secondary := Account{
|
||||
ID: 32102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 5,
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}},
|
||||
cfg: &config.Config{},
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.4", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32102), account.ID)
|
||||
|
||||
account, err = svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.3", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, int64(32101), account.ID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10103)
|
||||
|
||||
@ -276,14 +276,14 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
|
||||
}
|
||||
|
||||
// 9. Handle normal response
|
||||
@ -358,8 +358,9 @@ func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
requestedModel ...string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
|
||||
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError, requestedModel...)
|
||||
}
|
||||
|
||||
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
|
||||
|
||||
@ -207,14 +207,14 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account)
|
||||
return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
|
||||
}
|
||||
|
||||
// 8. Forward response
|
||||
|
||||
@ -337,7 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
@ -345,7 +345,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
}
|
||||
// Non-failover error: return Anthropic-formatted error to client
|
||||
return s.handleAnthropicErrorResponse(resp, c, account)
|
||||
return s.handleAnthropicErrorResponse(resp, c, account, billingModel)
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth && promptCacheKey != "" {
|
||||
@ -412,8 +412,9 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
requestedModel ...string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
|
||||
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError, requestedModel...)
|
||||
}
|
||||
|
||||
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
|
||||
|
||||
@ -188,14 +188,14 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, chatBody)
|
||||
return s.handleErrorResponse(ctx, resp, c, account, chatBody, billingModel)
|
||||
}
|
||||
|
||||
if clientStream {
|
||||
|
||||
@ -1312,8 +1312,8 @@ func openAICompactSupportTier(account *Account) int {
|
||||
|
||||
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
|
||||
// compact-support checks used during account selection.
|
||||
func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool {
|
||||
if account == nil || !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool {
|
||||
if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||
return false
|
||||
}
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
@ -1446,7 +1446,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
||||
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||
@ -1646,7 +1646,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
if clearSticky {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
}
|
||||
if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
||||
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
@ -1924,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
fresh = current
|
||||
}
|
||||
|
||||
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
|
||||
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
||||
@ -1938,7 +1938,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
if s.schedulerSnapshot == nil || s.accountRepo == nil {
|
||||
if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) {
|
||||
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
@ -1948,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
||||
if err != nil || latest == nil {
|
||||
return nil
|
||||
}
|
||||
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
|
||||
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||
@ -2067,8 +2067,12 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
|
||||
return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if len(requestedModel) > 0 {
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0])
|
||||
return
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
@ -2850,14 +2854,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
s.handleFailoverSideEffects(ctx, resp, account, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body, billingModel)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@ -3352,7 +3356,8 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -3395,7 +3400,8 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
|
||||
// 避免粘性路由继续复用刚被限流的账号。
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody)
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -4069,6 +4075,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
requestBody []byte,
|
||||
requestedModel ...string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
@ -4145,7 +4152,14 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
var reqModel string
|
||||
if len(requestedModel) > 0 {
|
||||
reqModel = strings.TrimSpace(requestedModel[0])
|
||||
}
|
||||
if reqModel == "" {
|
||||
reqModel, _, _ = extractOpenAIRequestMetaFromBody(requestBody)
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
@ -4222,6 +4236,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
writeError compatErrorWriter,
|
||||
requestedModel ...string,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
@ -4277,8 +4292,12 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
}
|
||||
|
||||
// Track rate limits and decide whether to trigger secondary failover.
|
||||
var modelForCooldown string
|
||||
if len(requestedModel) > 0 {
|
||||
modelForCooldown = requestedModel[0]
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body, modelForCooldown,
|
||||
)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
|
||||
@ -638,7 +638,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account, upstreamModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -1188,7 +1188,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
})
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account)
|
||||
s.handleFailoverSideEffects(upstreamCtx, resp, account, requestModel)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -153,7 +153,7 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) (shouldDisable bool) {
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
|
||||
// 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。
|
||||
@ -169,6 +169,10 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return false
|
||||
}
|
||||
|
||||
if len(requestedModel) > 0 && s.HandleUpstreamModelNotFound(ctx, account, requestedModel[0], statusCode, responseBody) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 先尝试临时不可调度规则(401除外)
|
||||
// 如果匹配成功,直接返回,不执行后续禁用逻辑
|
||||
if statusCode != 401 {
|
||||
@ -1616,9 +1620,51 @@ func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account
|
||||
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
}
|
||||
|
||||
// Settings applied when an upstream response is classified as "model not found".
const (
	// upstreamModelNotFoundCooldown is the offset from now used as the reset
	// time of the model-level rate limit.
	upstreamModelNotFoundCooldown = 30 * time.Minute
	// upstreamModelNotFoundReason is the reason recorded with that rate limit.
	upstreamModelNotFoundReason = "upstream_404_model_not_found"
)

// Byte limits for the temp-unschedulable handling; the code that consumes
// them is outside this excerpt.
const (
	tempUnschedBodyMaxBytes    = 64 << 10
	tempUnschedMessageMaxBytes = 2048
)
|
||||
|
||||
func (s *RateLimitService) HandleUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string, statusCode int, responseBody []byte) bool {
|
||||
if s == nil || account == nil || s.accountRepo == nil {
|
||||
return false
|
||||
}
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
return false
|
||||
}
|
||||
if !isUpstreamModelNotFoundError(statusCode, responseBody) {
|
||||
return false
|
||||
}
|
||||
modelKey := modelRateLimitKeyForUpstreamModelNotFound(ctx, account, requestedModel)
|
||||
if modelKey == "" {
|
||||
return false
|
||||
}
|
||||
resetAt := time.Now().Add(upstreamModelNotFoundCooldown)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, resetAt, upstreamModelNotFoundReason); err != nil {
|
||||
slog.Warn("upstream_model_not_found_set_model_rate_limit_failed", "account_id", account.ID, "model", modelKey, "error", err)
|
||||
return true
|
||||
}
|
||||
slog.Info("upstream_model_not_found_model_rate_limited", "account_id", account.ID, "model", modelKey, "reset_at", resetAt)
|
||||
return true
|
||||
}
|
||||
|
||||
func modelRateLimitKeyForUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string) string {
|
||||
modelKey := strings.TrimSpace(requestedModel)
|
||||
if account == nil || modelKey == "" {
|
||||
return modelKey
|
||||
}
|
||||
if account.Platform == PlatformAntigravity {
|
||||
if resolved := strings.TrimSpace(resolveFinalAntigravityModelKey(ctx, account, modelKey)); resolved != "" {
|
||||
return resolved
|
||||
}
|
||||
return modelKey
|
||||
}
|
||||
if mapped := strings.TrimSpace(account.GetMappedModel(modelKey)); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
return modelKey
|
||||
}
|
||||
|
||||
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
|
||||
@ -0,0 +1,127 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// modelNotFoundRateLimitCall captures the arguments of one SetModelRateLimit
// invocation so tests can assert on them afterwards.
type modelNotFoundRateLimitCall struct {
	accountID int64     // account id passed to the repository
	scope     string    // model key the rate limit was written under
	resetAt   time.Time // requested reset time
	reason    string    // first optional reason argument, empty when omitted
}
|
||||
|
||||
type modelNotFoundAccountRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
modelRateLimitCalls []modelNotFoundRateLimitCall
|
||||
modelRateLimitErr error
|
||||
}
|
||||
|
||||
func (r *modelNotFoundAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *modelNotFoundAccountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
call := modelNotFoundRateLimitCall{
|
||||
accountID: id,
|
||||
scope: scope,
|
||||
resetAt: resetAt,
|
||||
}
|
||||
if len(reason) > 0 {
|
||||
call.reason = reason[0]
|
||||
}
|
||||
r.modelRateLimitCalls = append(r.modelRateLimitCalls, call)
|
||||
return r.modelRateLimitErr
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_ModelNotFoundUsesModelRateLimit(t *testing.T) {
|
||||
repo := &modelNotFoundAccountRepoStub{}
|
||||
svc := &RateLimitService{accountRepo: repo}
|
||||
account := openAIModelNotFoundTempAccount()
|
||||
|
||||
handled := svc.HandleUpstreamError(
|
||||
context.Background(),
|
||||
account,
|
||||
http.StatusNotFound,
|
||||
http.Header{},
|
||||
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
|
||||
"gpt-5.4",
|
||||
)
|
||||
|
||||
require.True(t, handled)
|
||||
require.Zero(t, repo.tempCalls)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
call := repo.modelRateLimitCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, "gpt-5.4", call.scope)
|
||||
require.Equal(t, upstreamModelNotFoundReason, call.reason)
|
||||
require.WithinDuration(t, time.Now().Add(upstreamModelNotFoundCooldown), call.resetAt, 5*time.Second)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_ModelNotFoundWriteFailureDoesNotTempUnschedule(t *testing.T) {
|
||||
repo := &modelNotFoundAccountRepoStub{modelRateLimitErr: errors.New("write failed")}
|
||||
svc := &RateLimitService{accountRepo: repo}
|
||||
account := openAIModelNotFoundTempAccount()
|
||||
|
||||
handled := svc.HandleUpstreamError(
|
||||
context.Background(),
|
||||
account,
|
||||
http.StatusNotFound,
|
||||
http.Header{},
|
||||
[]byte(`{"error":{"code":"model_not_found","message":"model not found"}}`),
|
||||
"gpt-5.4",
|
||||
)
|
||||
|
||||
require.True(t, handled)
|
||||
require.Zero(t, repo.tempCalls)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_Bare404KeepsTempUnschedulablePath(t *testing.T) {
|
||||
repo := &modelNotFoundAccountRepoStub{}
|
||||
svc := &RateLimitService{accountRepo: repo}
|
||||
account := openAIModelNotFoundTempAccount()
|
||||
|
||||
handled := svc.HandleUpstreamError(
|
||||
context.Background(),
|
||||
account,
|
||||
http.StatusNotFound,
|
||||
http.Header{},
|
||||
[]byte(`{"error":{"message":"endpoint not found"}}`),
|
||||
"gpt-5.4",
|
||||
)
|
||||
|
||||
require.True(t, handled)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
}
|
||||
|
||||
func openAIModelNotFoundTempAccount() *Account {
|
||||
return &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(http.StatusNotFound),
|
||||
"keywords": []any{"not found"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -137,7 +137,7 @@ func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatforms(context.Cont
|
||||
func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time) error {
|
||||
func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user