From 1e406fed52f82b7e4812c02ad0d668b5446b97c2 Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 23 May 2026 10:18:43 +0800 Subject: [PATCH] fix: optimize OpenAI account cooldown scheduling --- backend/cmd/server/wire_gen.go | 70 +-- backend/internal/config/config.go | 7 +- backend/internal/config/config_test.go | 8 + .../handler/openai_chat_completions.go | 16 +- .../handler/openai_gateway_handler.go | 32 +- backend/internal/handler/openai_images.go | 16 +- backend/internal/server/api_contract_test.go | 2 +- backend/internal/service/admin_service.go | 6 + .../service/admin_service_clear_error_test.go | 4 +- .../internal/service/concurrency_service.go | 176 ++++++- .../service/concurrency_service_test.go | 45 ++ .../openai_account_runtime_block_fastpath.go | 169 +++++++ ...nai_account_runtime_block_fastpath_test.go | 101 ++++ .../service/openai_account_scheduler.go | 474 ++++++++++-------- .../openai_gateway_chat_completions.go | 4 +- .../openai_gateway_chat_completions_raw.go | 4 +- .../service/openai_gateway_messages.go | 4 +- .../openai_gateway_responses_chat_fallback.go | 4 +- .../service/openai_gateway_service.go | 257 ++++++---- .../internal/service/openai_token_provider.go | 8 + .../service/openai_token_provider_test.go | 5 + .../internal/service/openai_ws_forwarder.go | 2 +- backend/internal/service/ratelimit_service.go | 43 ++ .../service/ratelimit_service_403_test.go | 24 + .../service/ratelimit_service_clear_test.go | 3 + .../scheduler_snapshot_hydration_test.go | 25 + .../internal/service/token_refresh_service.go | 23 + backend/internal/service/wire.go | 4 + deploy/config.example.yaml | 3 + 29 files changed, 1169 insertions(+), 370 deletions(-) create mode 100644 backend/internal/service/openai_account_runtime_block_fastpath.go create mode 100644 backend/internal/service/openai_account_runtime_block_fastpath_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f5f747cb..e39a645a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) + gatewayCache := repository.NewGatewayCache(redisClient) + schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) + schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) - adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) - sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - rpmCache := repository.NewRPMCache(redisClient) - groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) - groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) - claudeOAuthClient := repository.NewClaudeOAuthClient() - oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) - openAIOAuthClient := repository.NewOpenAIOAuthClient() - openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) - geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) - geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - driveClient := repository.NewGeminiDriveClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) - antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) + pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) + pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) + if err != nil { + return nil, err + } + billingService := service.NewBillingService(configConfig, pricingService) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) @@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) + deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) + openAIOAuthClient := repository.NewOpenAIOAuthClient() + openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) + channelRepository := repository.NewChannelRepository(db) + channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) + modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) + notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) + sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) + rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) + claudeOAuthClient := repository.NewClaudeOAuthClient() + oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) + geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) + geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() + driveClient := repository.NewGeminiDriveClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) + antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() @@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) - gatewayCache := repository.NewGatewayCache(redisClient) - schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) - schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) @@ -173,24 +188,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) - usageBillingRepository := repository.NewUsageBillingRepository(client, db) - pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) - pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) - if err != nil { - return nil, err - } - billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) - deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) digestSessionStore := service.NewDigestSessionStore() - channelRepository := repository.NewChannelRepository(db) - channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) - modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) - balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) - openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -261,7 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5048e791..661e3296 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1009,7 +1009,8 @@ type GatewaySchedulingConfig struct { FallbackSelectionMode string `mapstructure:"fallback_selection_mode"` // 负载计算 - LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"` // 快照桶读取时的 MGET 分块大小 SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` // 快照重建时的缓存写入分块大小 @@ -1828,6 +1829,7 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200) viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) @@ -2634,6 +2636,9 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") } + if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 { + return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative") + } if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 { return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a47de2f8..99fec46c 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { if !cfg.Gateway.Scheduling.LoadBatchEnabled { t.Fatalf("LoadBatchEnabled = false, want true") } + if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 { + t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) + } if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) } @@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, wantErr: "gateway.scheduling.sticky_session_max_waiting", }, + { + name: "gateway scheduling load batch cache ttl", + mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 }, + wantErr: "gateway.scheduling.load_batch_cache_ttl_ms", + }, { name: "gateway scheduling outbox poll", mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 }, diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index f7269214..4d523dba 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai_chat_completions.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index e7ba699d..9c5560f5 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), @@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if channelMappingMsg.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel) } - result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai_messages.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 9cd46dc0..e6c37272 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -258,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai.images.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 7b4c70cf..fcf37ed7 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1258,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 843af013..44684206 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -531,6 +531,7 @@ type adminServiceImpl struct { defaultSubAssigner DefaultSubscriptionAssigner userSubRepo UserSubscriptionRepository privacyClientFactory PrivacyClientFactory + runtimeBlocker AccountRuntimeBlocker } type userGroupRateBatchReader interface { @@ -556,6 +557,7 @@ func NewAdminService( defaultSubAssigner DefaultSubscriptionAssigner, userSubRepo UserSubscriptionRepository, privacyClientFactory PrivacyClientFactory, + runtimeBlocker AccountRuntimeBlocker, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -575,6 +577,7 @@ func NewAdminService( defaultSubAssigner: defaultSubAssigner, userSubRepo: userSubRepo, privacyClientFactory: privacyClientFactory, + runtimeBlocker: runtimeBlocker, } } @@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil { return nil, err } + if s.runtimeBlocker != nil { + s.runtimeBlocker.ClearAccountSchedulingBlock(id) + } return s.accountRepo.GetByID(ctx, id) } diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go index 141466dc..41418c4f 100644 --- a/backend/internal/service/admin_service_clear_error_test.go +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes TempUnschedulableReason: "missing refresh token", }, } - svc := &adminServiceImpl{accountRepo: repo} + blocker := &runtimeBlockRecorder{} + svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker} updated, err := svc.ClearAccountError(context.Background(), 31) require.NoError(t, err) @@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes require.Nil(t, updated.RateLimitResetAt) require.Nil(t, updated.TempUnschedulableUntil) require.Empty(t, updated.TempUnschedulableReason) + require.Equal(t, []int64{31}, blocker.clearedIDs) } diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 386d5ed0..712fc1a7 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -3,13 +3,17 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/binary" + "encoding/hex" "os" "strconv" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // ConcurrencyCache 定义并发控制的缓存接口 @@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error } const ( - // Default extra wait slots beyond concurrency limit + // 默认等待队列额外槽位 defaultExtraWaitSlots = 20 + + defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond + accountLoadBatchFetchTimeout = 3 * time.Second + maxAccountLoadBatchCacheEntries = 256 ) -// ConcurrencyService manages concurrent request limiting for accounts and users +// ConcurrencyService 管理账号和用户的并发限制。 type ConcurrencyService struct { cache ConcurrencyCache + + accountLoadCacheTTL atomic.Int64 + accountLoadCacheMu sync.RWMutex + accountLoadCache map[string]cachedAccountLoadBatch + accountLoadGroup singleflight.Group } -// NewConcurrencyService creates a new ConcurrencyService +type cachedAccountLoadBatch struct { + loadMap map[int64]*AccountLoadInfo + expiresAt time.Time +} + +// NewConcurrencyService 创建并发控制服务。 func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService { - return &ConcurrencyService{cache: cache} + svc := &ConcurrencyService{ + cache: cache, + accountLoadCache: make(map[string]cachedAccountLoadBatch), + } + svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL) + return svc +} + +// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。 +func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) { + if s == nil { + return + } + s.accountLoadCacheTTL.Store(int64(ttl)) + if ttl <= 0 { + s.accountLoadCacheMu.Lock() + s.accountLoadCache = make(map[string]cachedAccountLoadBatch) + s.accountLoadCacheMu.Unlock() + } } // AcquireResult represents the result of acquiring a concurrency slot @@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } -// GetAccountsLoadBatch returns load info for multiple accounts. +// GetAccountsLoadBatch 批量获取账号负载信息。 func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return s.getAccountsLoadBatch(ctx, accounts, true) +} + +// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。 +func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return s.getAccountsLoadBatch(ctx, accounts, false) +} + +func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) { + if len(accounts) == 0 { + return map[int64]*AccountLoadInfo{}, nil + } if s.cache == nil { return map[int64]*AccountLoadInfo{}, nil } - return s.cache.GetAccountsLoadBatch(ctx, accounts) + + ttl := time.Duration(s.accountLoadCacheTTL.Load()) + if !allowCache || ttl <= 0 { + return s.fetchAccountsLoadBatch(ctx, accounts) + } + + key := accountLoadBatchCacheKey(accounts) + if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok { + return cached, nil + } + + value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) { + now := time.Now() + if cached, ok := s.getCachedAccountLoadBatch(key, now); ok { + return cached, nil + } + loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts) + if fetchErr != nil { + return nil, fetchErr + } + cached := cloneAccountLoadMap(loadMap) + s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl)) + return cached, nil + }) + if err != nil { + return nil, err + } + loadMap, _ := value.(map[int64]*AccountLoadInfo) + if loadMap == nil { + return map[int64]*AccountLoadInfo{}, nil + } + return loadMap, nil +} + +func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if s.cache == nil { + return map[int64]*AccountLoadInfo{}, nil + } + baseCtx := context.Background() + if ctx != nil { + baseCtx = context.WithoutCancel(ctx) + } + redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout) + defer cancel() + return s.cache.GetAccountsLoadBatch(redisCtx, accounts) +} + +func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) { + s.accountLoadCacheMu.RLock() + cached, ok := s.accountLoadCache[key] + s.accountLoadCacheMu.RUnlock() + if !ok { + return nil, false + } + if !now.Before(cached.expiresAt) { + s.accountLoadCacheMu.Lock() + if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) { + delete(s.accountLoadCache, key) + } + s.accountLoadCacheMu.Unlock() + return nil, false + } + return cached.loadMap, true +} + +func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) { + s.accountLoadCacheMu.Lock() + if s.accountLoadCache == nil { + s.accountLoadCache = make(map[string]cachedAccountLoadBatch) + } + if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries { + now := time.Now() + for cacheKey, cached := range s.accountLoadCache { + if !now.Before(cached.expiresAt) { + delete(s.accountLoadCache, cacheKey) + } + } + for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries { + for cacheKey := range s.accountLoadCache { + delete(s.accountLoadCache, cacheKey) + break + } + } + } + s.accountLoadCache[key] = cachedAccountLoadBatch{ + loadMap: loadMap, + expiresAt: expiresAt, + } + s.accountLoadCacheMu.Unlock() +} + +func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string { + hash := sha256.New() + var buf [16]byte + for _, account := range accounts { + binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID)) + binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency))) + _, _ = hash.Write(buf[:]) + } + sum := hash.Sum(nil) + return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum) +} + +func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo { + if len(loadMap) == 0 { + return map[int64]*AccountLoadInfo{} + } + clone := make(map[int64]*AccountLoadInfo, len(loadMap)) + for accountID, loadInfo := range loadMap { + if loadInfo == nil { + clone[accountID] = nil + continue + } + copied := *loadInfo + clone[accountID] = &copied + } + return clone } // GetUsersLoadBatch returns load info for multiple users. diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 078ba0dc..7d5f501d 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -7,7 +7,9 @@ import ( "errors" "strconv" "strings" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct { // 记录调用 releasedAccountIDs []int64 releasedRequestIDs []string + loadBatchCalls atomic.Int64 } var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) @@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in return nil } func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + c.loadBatchCalls.Add(1) return c.loadBatch, c.loadBatchErr } func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { @@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) { require.Empty(t, result) } +func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) { + cache := &stubConcurrencyCacheForTest{ + loadBatch: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20}, + }, + } + svc := NewConcurrencyService(cache) + svc.SetAccountLoadBatchCacheTTL(time.Second) + + accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}} + first, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 1, first[int64(1)].CurrentConcurrency) + + cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80} + second, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 1, second[int64(1)].CurrentConcurrency) + require.Equal(t, int64(1), cache.loadBatchCalls.Load()) +} + +func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) { + cache := &stubConcurrencyCacheForTest{ + loadBatch: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20}, + }, + } + svc := NewConcurrencyService(cache) + svc.SetAccountLoadBatchCacheTTL(time.Second) + + accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}} + _, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + + cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80} + fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency) + require.Equal(t, int64(2), cache.loadBatchCalls.Load()) +} + func TestIncrementWaitCount_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitAllowed: true} svc := NewConcurrencyService(cache) diff --git a/backend/internal/service/openai_account_runtime_block_fastpath.go b/backend/internal/service/openai_account_runtime_block_fastpath.go new file mode 100644 index 00000000..41a309fd --- /dev/null +++ b/backend/internal/service/openai_account_runtime_block_fastpath.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "net/http" + "time" +) + +const ( + openAIAccountStateUpdateTimeout = 5 * time.Second + openAIOAuth429FallbackCooldown = 5 * time.Second + openAIStopSchedulingBridgeCooldown = 2 * time.Minute + openAIOAuth429StormWindow = 10 * time.Second + openAIOAuth429StormThreshold = 20 + openAIOAuth429StormMaxAccountSwitches = 1 +) + +func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, openAIAccountStateUpdateTimeout) +} + +func isOpenAIOAuthAccount(account *Account) bool { + return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth +} + +func isOpenAIAccount(account *Account) bool { + return account != nil && account.Platform == PlatformOpenAI +} + +func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool { + stateCtx, cancel := openAIAccountStateContext(ctx) + defer cancel() + + if statusCode == http.StatusTooManyRequests { + s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody) + } + if s == nil || account == nil || s.rateLimitService == nil { + return false + } + shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody) + if shouldDisable { + s.BlockAccountScheduling(account, time.Time{}, "upstream_disable") + } + return shouldDisable +} + +func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { + if s == nil || !isOpenAIOAuthAccount(account) { + return + } + s.recordOpenAIOAuth429() + + cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown) + if s.rateLimitService != nil { + if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) { + cooldownUntil = *resetAt + } else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil { + if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) { + cooldownUntil = resetAt + } + } else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 { + cooldownUntil = time.Now().Add(cooldown) + } + } + s.BlockAccountScheduling(account, cooldownUntil, "429") +} + +func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) { + if s == nil || !isOpenAIAccount(account) { + return + } + now := time.Now() + blockUntil := until + if blockUntil.IsZero() || !blockUntil.After(now) { + blockUntil = now.Add(openAIStopSchedulingBridgeCooldown) + } + + for { + current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID) + if !loaded { + actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil) + if !stored { + return + } + current = actual + } + + currentUntil, ok := current.(time.Time) + if !ok || currentUntil.IsZero() { + if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) { + return + } + continue + } + if currentUntil.After(blockUntil) { + return + } + if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) { + return + } + } +} + +func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiAccountRuntimeBlockUntil.Delete(accountID) +} + +func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool { + if s == nil || !isOpenAIAccount(account) { + return false + } + value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID) + if !ok { + return false + } + cooldownUntil, ok := value.(time.Time) + if !ok || cooldownUntil.IsZero() { + s.openaiAccountRuntimeBlockUntil.Delete(account.ID) + return false + } + if time.Now().Before(cooldownUntil) { + return true + } + s.openaiAccountRuntimeBlockUntil.Delete(account.ID) + return false +} + +func (s *OpenAIGatewayService) recordOpenAIOAuth429() { + if s == nil { + return + } + now := time.Now() + windowStart := s.openaiOAuth429WindowStartUnixNano.Load() + if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow { + if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) { + s.openaiOAuth429WindowCount.Store(1) + return + } + } + s.openaiOAuth429WindowCount.Add(1) +} + +func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool { + if s == nil { + return false + } + windowStart := s.openaiOAuth429WindowStartUnixNano.Load() + if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow { + return false + } + return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold +} + +func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool { + if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches { + return false + } + if !isOpenAIOAuthAccount(account) { + return false + } + return s.isOpenAIOAuth429Storm() +} diff --git a/backend/internal/service/openai_account_runtime_block_fastpath_test.go b/backend/internal/service/openai_account_runtime_block_fastpath_test.go new file mode 100644 index 00000000..95336e81 --- /dev/null +++ b/backend/internal/service/openai_account_runtime_block_fastpath_test.go @@ -0,0 +1,101 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil) + apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil) + + require.False(t, shouldDisable) + require.False(t, apiKeyShouldDisable) + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) + require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount)) +} + +func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code") + + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth} + + svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code") + + require.False(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) { + gateway := &OpenAIGatewayService{} + repo := &rateLimitAccountRepoStub{} + rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + rateLimitService.SetAccountRuntimeBlocker(gateway) + account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth} + + shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden")) + + require.True(t, shouldDisable) + require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + longUntil := time.Now().Add(10 * time.Minute) + + svc.BlockAccountScheduling(account, longUntil, "oauth_401") + svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable") + + value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID) + require.True(t, ok) + actualUntil, ok := value.(time.Time) + require.True(t, ok) + require.WithinDuration(t, longUntil, actualUntil, time.Second) +} + +func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429") + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) + + svc.ClearAccountSchedulingBlock(account.ID) + require.False(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1)) + + for i := 0; i < openAIOAuth429StormThreshold; i++ { + svc.recordOpenAIOAuth429() + } + + require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0)) +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index f29d3607..a8ac391a 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct { loadSkewMilliTotal atomic.Int64 } +type openAIAccountLoadPlan struct { + allCandidates []openAIAccountCandidateScore + candidates []openAIAccountCandidateScore + staleSnapshotCompactRetry []openAIAccountCandidateScore + selectionOrder []openAIAccountCandidateScore + candidateCount int + topK int + loadSkew float64 +} + func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { if m == nil { return @@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if acquireErr == nil && result.Acquired { + if acquireErr == nil && result != nil && result.Acquired { _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) return &AccountSelectionResult{ Account: account, @@ -586,6 +596,231 @@ func buildOpenAIWeightedSelectionOrder( return order } +func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan( + req OpenAIAccountScheduleRequest, + filtered []*Account, + loadMap map[int64]*AccountLoadInfo, +) openAIAccountLoadPlan { + allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + errorRate, ttft, hasTTFT := 0.0, 0.0, false + if s.stats != nil { + errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID) + } + allCandidates = append(allCandidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + + candidates := allCandidates + staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates)) + if req.RequireCompact { + candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates)) + for _, candidate := range allCandidates { + if openAICompactSupportTier(candidate.account) == 0 { + staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate) + continue + } + candidates = append(candidates, candidate) + } + } + + plan := openAIAccountLoadPlan{ + allCandidates: allCandidates, + candidates: candidates, + staleSnapshotCompactRetry: staleSnapshotCompactRetry, + candidateCount: len(candidates), + } + if len(candidates) == 0 { + plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan) + return plan + } + + minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + for _, candidate := range candidates { + if candidate.account.Priority < minPriority { + minPriority = candidate.account.Priority + } + if candidate.account.Priority > maxPriority { + maxPriority = candidate.account.Priority + } + if candidate.loadInfo.WaitingCount > maxWaiting { + maxWaiting = candidate.loadInfo.WaitingCount + } + if candidate.hasTTFT && candidate.ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = candidate.ttft, candidate.ttft + hasTTFTSample = true + } else { + if candidate.ttft < minTTFT { + minTTFT = candidate.ttft + } + if candidate.ttft > maxTTFT { + maxTTFT = candidate.ttft + } + } + } + loadRate := float64(candidate.loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + } + plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + plan.candidates = candidates + + plan.topK = s.service.openAIWSLBTopK() + if plan.topK > len(candidates) { + plan.topK = len(candidates) + } + if plan.topK <= 0 { + plan.topK = 1 + } + + plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan) + return plan +} + +func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder( + req OpenAIAccountScheduleRequest, + plan openAIAccountLoadPlan, +) []openAIAccountCandidateScore { + buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 || plan.topK <= 0 { + return nil + } + groupTopK := plan.topK + if groupTopK > len(pool) { + groupTopK = len(pool) + } + ranked := selectTopKOpenAICandidates(pool, groupTopK) + return buildOpenAIWeightedSelectionOrder(ranked, req) + } + + if req.RequireCompact { + supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates)) + unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates)) + for _, candidate := range plan.candidates { + switch openAICompactSupportTier(candidate.account) { + case 2: + supported = append(supported, candidate) + case 1: + unknown = append(unknown, candidate) + } + } + selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates)) + selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...) + selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...) + if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil { + selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...) + } + return selectionOrder + } + + return buildSelectionOrder(plan.candidates) +} + +func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 { + return nil + } + ordered := append([]openAIAccountCandidateScore(nil), pool...) + sort.SliceStable(ordered, func(i, j int) bool { + a, b := ordered[i], ordered[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount { + return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + return ordered +} + +func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder( + ctx context.Context, + req OpenAIAccountScheduleRequest, + selectionOrder []openAIAccountCandidateScore, +) (*AccountSelectionResult, bool, error) { + compactBlocked := false + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { + continue + } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { + continue + } + if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { + compactBlocked = true + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if acquireErr != nil { + return nil, compactBlocked, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, compactBlocked, nil + } + } + return nil, compactBlocked, nil +} + func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ctx context.Context, req OpenAIAccountScheduleRequest, @@ -616,8 +851,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if !account.IsSchedulable() || !account.IsOpenAI() { continue } + if s.service.isOpenAIAccountRuntimeBlocked(account) { + continue + } // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() { + s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set") _ = s.service.accountRepo.SetError(ctx, account.ID, fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue @@ -645,208 +884,46 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } } - allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered)) - for _, account := range filtered { - loadInfo := loadMap[account.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: account.ID} - } - errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) - allCandidates = append(allCandidates, openAIAccountCandidateScore{ - account: account, - loadInfo: loadInfo, - errorRate: errorRate, - ttft: ttft, - hasTTFT: hasTTFT, - }) + plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap) + candidateCount := plan.candidateCount + topK := plan.topK + loadSkew := plan.loadSkew + selectionOrder := plan.selectionOrder + if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 { + return nil, 0, 0, 0, ErrNoAvailableCompactAccounts } - - // Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用 - // 时作为最后兜底(snapshot 可能已陈旧)。 - candidates := allCandidates - staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates)) - if req.RequireCompact { - candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates)) - for _, candidate := range allCandidates { - if openAICompactSupportTier(candidate.account) == 0 { - staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate) - continue - } - candidates = append(candidates, candidate) - } - if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 { - return nil, 0, 0, 0, ErrNoAvailableCompactAccounts - } - } - - candidateCount := len(candidates) - loadSkew := 0.0 - if len(candidates) > 0 { - minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority - maxWaiting := 1 - loadRateSum := 0.0 - loadRateSumSquares := 0.0 - minTTFT, maxTTFT := 0.0, 0.0 - hasTTFTSample := false - for _, candidate := range candidates { - if candidate.account.Priority < minPriority { - minPriority = candidate.account.Priority - } - if candidate.account.Priority > maxPriority { - maxPriority = candidate.account.Priority - } - if candidate.loadInfo.WaitingCount > maxWaiting { - maxWaiting = candidate.loadInfo.WaitingCount - } - if candidate.hasTTFT && candidate.ttft > 0 { - if !hasTTFTSample { - minTTFT, maxTTFT = candidate.ttft, candidate.ttft - hasTTFTSample = true - } else { - if candidate.ttft < minTTFT { - minTTFT = candidate.ttft - } - if candidate.ttft > maxTTFT { - maxTTFT = candidate.ttft - } - } - } - loadRate := float64(candidate.loadInfo.LoadRate) - loadRateSum += loadRate - loadRateSumSquares += loadRate * loadRate - } - loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) - - weights := s.service.openAIWSSchedulerWeights() - for i := range candidates { - item := &candidates[i] - priorityFactor := 1.0 - if maxPriority > minPriority { - priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) - } - loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) - queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) - errorFactor := 1 - clamp01(item.errorRate) - ttftFactor := 0.5 - if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { - ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) - } - - item.score = weights.Priority*priorityFactor + - weights.Load*loadFactor + - weights.Queue*queueFactor + - weights.ErrorRate*errorFactor + - weights.TTFT*ttftFactor - } - } - - topK := 0 - if len(candidates) > 0 { - topK = s.service.openAIWSLBTopK() - if topK > len(candidates) { - topK = len(candidates) - } - if topK <= 0 { - topK = 1 - } - } - - buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { - if len(pool) == 0 || topK <= 0 { - return nil - } - groupTopK := topK - if groupTopK > len(pool) { - groupTopK = len(pool) - } - ranked := selectTopKOpenAICandidates(pool, groupTopK) - return buildOpenAIWeightedSelectionOrder(ranked, req) - } - sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { - if len(pool) == 0 { - return nil - } - ordered := append([]openAIAccountCandidateScore(nil), pool...) - sort.SliceStable(ordered, func(i, j int) bool { - a, b := ordered[i], ordered[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount { - return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - return ordered - } - - selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates)) - if req.RequireCompact { - supported := make([]openAIAccountCandidateScore, 0, len(candidates)) - unknown := make([]openAIAccountCandidateScore, 0, len(candidates)) - for _, candidate := range candidates { - switch openAICompactSupportTier(candidate.account) { - case 2: - supported = append(supported, candidate) - case 1: - unknown = append(unknown, candidate) - } - } - if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil { - return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts - } - selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...) - selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...) - if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil { - selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...) - } - } else { - selectionOrder = buildSelectionOrder(candidates) + if req.RequireCompact && len(selectionOrder) == 0 && s.service.schedulerSnapshot == nil { + return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts } if len(selectionOrder) == 0 { - return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0) + return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0) } - compactBlocked := false - for i := 0; i < len(selectionOrder); i++ { - candidate := selectionOrder[i] - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { - continue - } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { - continue - } - if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { - compactBlocked = true - continue - } - result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if acquireErr != nil { - return nil, candidateCount, topK, loadSkew, acquireErr - } - if result != nil && result.Acquired { - if req.SessionHash != "" { - _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder) + if acquireErr != nil { + return nil, candidateCount, topK, loadSkew, acquireErr + } + if result != nil { + return result, candidateCount, topK, loadSkew, nil + } + + if s.service.concurrencyService != nil { + if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil { + freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap) + if len(freshPlan.selectionOrder) > 0 { + freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder) + if freshAcquireErr != nil { + return nil, candidateCount, topK, loadSkew, freshAcquireErr + } + if freshResult != nil { + return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil + } + compactBlocked = compactBlocked || freshCompactBlocked + selectionOrder = freshPlan.selectionOrder + candidateCount = freshPlan.candidateCount + topK = freshPlan.topK + loadSkew = freshPlan.loadSkew } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, candidateCount, topK, loadSkew, nil } } @@ -893,6 +970,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C if account == nil { return false } + if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) { + return false + } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { return false } diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index f8b23a28..27eb211e 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index c585290e..ad6d3e8d 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 6d74f7dd..336a7d79 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback.go b/backend/internal/service/openai_gateway_responses_chat_fallback.go index 1d28a9c2..c3ebc35c 100644 --- a/backend/internal/service/openai_gateway_responses_chat_fallback.go +++ b/backend/internal/service/openai_gateway_responses_chat_fallback.go @@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f13c4748..e47d1989 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -354,6 +354,9 @@ type OpenAIGatewayService struct { openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time + openaiOAuth429WindowStartUnixNano atomic.Int64 + openaiOAuth429WindowCount atomic.Int64 openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter codexSnapshotThrottle *accountWriteThrottle @@ -417,6 +420,12 @@ func NewOpenAIGatewayService( responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } + if rateLimitService != nil { + rateLimitService.SetAccountRuntimeBlocker(svc) + } + if openAITokenProvider != nil { + openAITokenProvider.SetAccountRuntimeBlocker(svc) + } svc.logOpenAIWSModeBootstrap() return svc } @@ -1381,13 +1390,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) } + hydrated, err := s.hydrateSelectedAccount(ctx, selected) + if err != nil { + return nil, err + } + // 4. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } - return s.hydrateSelectedAccount(ctx, selected) + return hydrated, nil } // tryStickySessionHit 尝试从粘性会话获取账号。 @@ -1430,6 +1444,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { return nil } + if s.isOpenAIAccountRuntimeBlocked(account) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) @@ -1575,8 +1593,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex return nil, err } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + if err == nil && result != nil && result.Acquired { + return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc) } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) @@ -1627,13 +1645,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else if s.isOpenAIAccountRuntimeBlocked(account) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc) + if selectErr != nil { + return nil, selectErr + } _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return selection, nil } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) @@ -1665,6 +1689,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex if !acc.IsSchedulable() { continue } + if s.isOpenAIAccountRuntimeBlocked(acc) { + continue + } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } @@ -1687,6 +1714,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex }) } + tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) { + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) == 0 { + return nil, false, nil + } + + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + shuffleWithinSortGroups(available) + + selectionOrder := make([]accountWithLoad, 0, len(available)) + if requireCompact { + appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { + for _, item := range available { + if openAICompactSupportTier(item.account) == tier { + out = append(out, item) + } + } + return out + } + selectionOrder = appendTier(selectionOrder, 2) + selectionOrder = appendTier(selectionOrder, 1) + // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 + // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 + selectionOrder = appendTier(selectionOrder, 0) + } else { + selectionOrder = append(selectionOrder, available...) + } + + for _, item := range selectionOrder { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) + if fresh == nil { + continue + } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc) + if selectErr != nil { + return nil, true, selectErr + } + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + } + return selection, true, nil + } + } + return nil, true, nil + } + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { ordered := append([]*Account(nil), candidates...) @@ -1707,87 +1820,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex continue } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc) + if selectErr != nil { + return nil, selectErr + } if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + return selection, nil } } } else { - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - shuffleWithinSortGroups(available) - - selectionOrder := make([]accountWithLoad, 0, len(available)) - if requireCompact { - appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { - for _, item := range available { - if openAICompactSupportTier(item.account) == tier { - out = append(out, item) - } - } - return out - } - selectionOrder = appendTier(selectionOrder, 2) - selectionOrder = appendTier(selectionOrder, 1) - // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 - // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 - selectionOrder = appendTier(selectionOrder, 0) - } else { - selectionOrder = append(selectionOrder, available...) - } - - for _, item := range selectionOrder { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) - if fresh == nil { - continue - } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) - if fresh == nil { - continue - } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue - } - result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) - } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil { + return nil, selectErr + } else if selection != nil { + return selection, nil + } else if attempted { + if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil { + if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil { + return nil, selectErr + } else if selection != nil { + return selection, nil } } } @@ -1868,6 +1922,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { return nil } + if s.isOpenAIAccountRuntimeBlocked(fresh) { + return nil + } return fresh } @@ -1889,6 +1946,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { return nil } + if s.isOpenAIAccountRuntimeBlocked(latest) { + return nil + } return latest } @@ -1935,6 +1995,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account * }, nil } +func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) { + selection, err := s.newSelectionResult(ctx, account, true, release, nil) + if err != nil && release != nil { + release() + } + return selection, err +} + func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling @@ -1996,7 +2064,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } // Forward forwards request to OpenAI API @@ -3278,9 +3346,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -3321,12 +3387,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - // Passthrough mode preserves the raw upstream error response, but runtime - // account state still needs to be updated so sticky routing can stop - // reusing a freshly rate-limited account. - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + // 透传模式保留原始上游错误响应,但运行态账号状态仍需更新, + // 避免粘性路由继续复用刚被限流的账号。 + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4075,10 +4138,7 @@ func (s *OpenAIGatewayService) handleErrorResponse( } // Handle upstream error (mark account status) - shouldDisable := false - if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) kind := "http_error" if shouldDisable { kind = "failover" @@ -4210,12 +4270,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( } // Track rate limits and decide whether to trigger secondary failover. - shouldDisable := false - if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError( - c.Request.Context(), account, resp.StatusCode, resp.Header, body, - ) - } + shouldDisable := s.handleOpenAIAccountUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) kind := "http_error" if shouldDisable { kind = "failover" diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 5b55d200..94791189 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -80,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + runtimeBlocker AccountRuntimeBlocker metrics *openAITokenRuntimeMetricsStore refreshAPI *OAuthRefreshAPI executor OAuthRefreshExecutor @@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { p.refreshPolicy = policy } +func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + p.runtimeBlocker = blocker +} + func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { if p == nil { return OpenAITokenRuntimeMetrics{} @@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account if p == nil || p.accountRepo == nil || account == nil { return } + if p.runtimeBlocker != nil { + p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token") + } bgCtx := context.Background() if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil { slog.Warn("openai_token_provider.set_error_failed", diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index df2f0f3e..fb506523 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) cache.getErr = errors.New("simulated cache miss") provider := NewOpenAITokenProvider(repo, cache, nil) + blocker := &runtimeBlockRecorder{} + provider.SetAccountRuntimeBlocker(blocker) token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) @@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once") require.Contains(t, repo.lastErrorMsg, "refresh_token is missing") + require.Len(t, blocker.accounts, 1) + require.Equal(t, account.ID, blocker.accounts[0].ID) + require.Equal(t, "missing_refresh_token", blocker.reasons[0]) } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 920a2239..700dbedf 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { return } - s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) + s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) } func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 892d9aca..c3b160e7 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -28,10 +28,16 @@ type RateLimitService struct { openAI403CounterCache OpenAI403CounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator + runtimeBlocker AccountRuntimeBlocker usageCacheMu sync.RWMutex usageCache map[int64]*geminiUsageCacheEntry } +type AccountRuntimeBlocker interface { + BlockAccountScheduling(account *Account, until time.Time, reason string) + ClearAccountSchedulingBlock(accountID int64) +} + // SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。 type SuccessfulTestRecoveryResult struct { ClearedError bool @@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + s.runtimeBlocker = blocker +} + +func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) { + if s == nil || s.runtimeBlocker == nil || account == nil { + return + } + s.runtimeBlocker.BlockAccountScheduling(account, until, reason) +} + +func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) { + if s == nil || s.runtimeBlocker == nil || accountID <= 0 { + return + } + s.runtimeBlocker.ClearAccountSchedulingBlock(accountID) +} + // ErrorPolicyResult 表示错误策略检查的结果 type ErrorPolicyResult int @@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc cooldownMinutes = 10 } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + s.notifyAccountSchedulingBlocked(account, until, "oauth_401") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil { slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) } @@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { + s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error") if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err) return @@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute) reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg) + s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) s.handleAuthError(ctx, account, msg) @@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg + s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code") if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err) return @@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody) s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { + s.notifyAccountSchedulingBlocked(account, *resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 if result := calculateAnthropic429ResetTime(headers); result != nil { + s.notifyAccountSchedulingBlocked(account, result.resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 尝试解析 OpenAI 的 usage_limit_reached 错误 if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil { resetTime := time.Unix(*resetAt, 0) + s.notifyAccountSchedulingBlocked(account, resetTime, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 尝试解析 Gemini 格式(用于其他平台) if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil { resetTime := time.Unix(*resetAt, 0) + s.notifyAccountSchedulingBlocked(account, resetTime, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head resetAt := time.Unix(ts, 0) // 标记限流状态 + s.notifyAccountSchedulingBlocked(account, resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun resetAt := time.Now().Add(cooldown) slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String()) + s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } @@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) { } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + s.notifyAccountSchedulingBlocked(account, until, "529") if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { slog.Warn("overload_set_failed", "account_id", account.ID, "error", err) return @@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) } } s.ResetOpenAI403Counter(ctx, accountID) + s.notifyAccountSchedulingBlockCleared(accountID) return nil } @@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in } if result.ClearedError || result.ClearedRateLimit { s.ResetOpenAI403Counter(ctx, accountID) + if result.ClearedError && !result.ClearedRateLimit { + s.notifyAccountSchedulingBlockCleared(accountID) + } } return result, nil @@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err) } + s.notifyAccountSchedulingBlockCleared(accountID) return nil } @@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account reason = strings.TrimSpace(state.ErrorMessage) } + s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err) return false @@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, reason = state.ErrorMessage } + s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err) return false @@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool { errorMsg := "Stream data interval timeout (repeated failures) for model: " + model + s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error") if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err) return false diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go index 2fd11b71..9d4b2714 100644 --- a/backend/internal/service/ratelimit_service_403_test.go +++ b/backend/internal/service/ratelimit_service_403_test.go @@ -6,16 +6,36 @@ import ( "context" "net/http" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) +type runtimeBlockRecorder struct { + accounts []*Account + until []time.Time + reasons []string + clearedIDs []int64 +} + +func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) { + r.accounts = append(r.accounts, account) + r.until = append(r.until, until) + r.reasons = append(r.reasons, reason) +} + +func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) { + r.clearedIDs = append(r.clearedIDs, accountID) +} + func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) { repo := &rateLimitAccountRepoStub{} counter := &openAI403CounterCacheStub{counts: []int64{1}} + blocker := &runtimeBlockRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetOpenAI403CounterCache(counter) + service.SetAccountRuntimeBlocker(blocker) account := &Account{ ID: 301, Platform: PlatformOpenAI, @@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable require.Equal(t, 1, repo.tempCalls) require.Contains(t, repo.lastTempReason, "temporary edge rejection") require.Contains(t, repo.lastTempReason, "(1/3)") + require.Len(t, blocker.accounts, 1) + require.Equal(t, account.ID, blocker.accounts[0].ID) + require.Equal(t, "openai_403_temp", blocker.reasons[0]) + require.True(t, blocker.until[0].After(time.Now())) } func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) { diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go index 1d7a02fc..e8d312f9 100644 --- a/backend/internal/service/ratelimit_service_clear_test.go +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi }, } cache := &tempUnschedCacheRecorder{} + blocker := &runtimeBlockRecorder{} svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + svc.SetAccountRuntimeBlocker(blocker) result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42) require.NoError(t, err) @@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi require.Equal(t, 1, repo.clearModelRateLimitCalls) require.Equal(t, 1, repo.clearTempUnschedCalls) require.Equal(t, []int64{42}, cache.deletedIDs) + require.Equal(t, []int64{42}, blocker.clearedIDs) } func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) { diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go index 0b32c2ad..778cab23 100644 --- a/backend/internal/service/scheduler_snapshot_hydration_test.go +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul } } +func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) { + cache := &snapshotHydrationCache{ + accounts: map[int64]*Account{}, + } + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil) + svc := &OpenAIGatewayService{ + schedulerSnapshot: schedulerSnapshot, + } + releaseCalls := 0 + + selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() { + releaseCalls++ + }) + + if err == nil { + t.Fatalf("expected hydration error") + } + if selection != nil { + t.Fatalf("expected nil selection on hydration error") + } + if releaseCalls != 1 { + t.Fatalf("expected release to be called once, got %d", releaseCalls) + } +} + func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { cache := &snapshotHydrationCache{ snapshot: []*Account{ diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 96428847..c5f1991b 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -27,6 +27,7 @@ type TokenRefreshService struct { schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存 refreshAPI *OAuthRefreshAPI // 统一刷新 API + runtimeBlocker AccountRuntimeBlocker // OpenAI privacy: 刷新成功后检查并设置 training opt-out privacyClientFactory PrivacyClientFactory @@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) { s.refreshPolicy = policy } +func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + s.runtimeBlocker = blocker +} + +func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) { + if s == nil || s.runtimeBlocker == nil || account == nil { + return + } + s.runtimeBlocker.BlockAccountScheduling(account, until, reason) +} + +func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) { + if s == nil || s.runtimeBlocker == nil || accountID <= 0 { + return + } + s.runtimeBlocker.ClearAccountSchedulingBlock(accountID) +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { @@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc // 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回 if isNonRetryableRefreshError(err) { errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) + s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable") if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { slog.Error("token_refresh.set_error_status_failed", "account_id", account.ID, @@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc // 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试) until := time.Now().Add(tokenRefreshTempUnschedDuration) reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr) + s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted") if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil { slog.Warn("token_refresh.set_temp_unschedulable_failed", "account_id", account.ID, @@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A ) } else { slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) + s.notifyAccountSchedulingBlockCleared(account.ID) } } // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) @@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A ) } else { slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) + s.notifyAccountSchedulingBlockCleared(account.ID) } // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 if s.tempUnschedCache != nil { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 26bda593..43748aa5 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -59,6 +59,7 @@ func ProvideTokenRefreshService( privacyClientFactory PrivacyClientFactory, proxyRepo ProxyRepository, refreshAPI *OAuthRefreshAPI, + runtimeBlocker AccountRuntimeBlocker, ) *TokenRefreshService { svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) // 注入 OpenAI privacy opt-out 依赖 @@ -67,6 +68,7 @@ func ProvideTokenRefreshService( svc.SetRefreshAPI(refreshAPI) // 调用侧显式注入后台刷新策略,避免策略漂移 svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy()) + svc.SetAccountRuntimeBlocker(runtimeBlocker) svc.Start() return svc } @@ -183,6 +185,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err) } if cfg != nil { + svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond) svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) } return svc @@ -455,6 +458,7 @@ var ProviderSet = wire.NewSet( NewAdminService, NewGatewayService, NewOpenAIGatewayService, + wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)), NewOAuthService, NewOpenAIOAuthService, NewGeminiOAuthService, diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 0d61b710..8e9b0e3b 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -405,6 +405,9 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true + # Tiny in-process TTL for batch load reads in milliseconds (0 disables) + # 调度批量负载读取的进程内短缓存 TTL(毫秒,0 表示禁用) + load_batch_cache_ttl_ms: 200 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s