fix: optimize OpenAI account cooldown scheduling
This commit is contained in:
parent
f59d9a5f8e
commit
1e406fed52
@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
@ -173,24 +188,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@ -261,7 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
|
||||
@ -1009,7 +1009,8 @@ type GatewaySchedulingConfig struct {
|
||||
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||
|
||||
// 负载计算
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"`
|
||||
// 快照桶读取时的 MGET 分块大小
|
||||
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
|
||||
// 快照重建时的缓存写入分块大小
|
||||
@ -1828,6 +1829,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||
viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200)
|
||||
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
|
||||
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
|
||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||
@ -2634,6 +2636,9 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
|
||||
}
|
||||
if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 {
|
||||
return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative")
|
||||
}
|
||||
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
|
||||
}
|
||||
|
||||
@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 {
|
||||
t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling load batch cache ttl",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 },
|
||||
wantErr: "gateway.scheduling.load_batch_cache_ttl_ms",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox poll",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
||||
|
||||
@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
}()
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
}()
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
if channelMappingMsg.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
}()
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||
}()
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -258,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -1258,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@ -531,6 +531,7 @@ type adminServiceImpl struct {
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@ -556,6 +557,7 @@ func NewAdminService(
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
runtimeBlocker AccountRuntimeBlocker,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@ -575,6 +577,7 @@ func NewAdminService(
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
privacyClientFactory: privacyClientFactory,
|
||||
runtimeBlocker: runtimeBlocker,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.runtimeBlocker != nil {
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(id)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
|
||||
@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
||||
TempUnschedulableReason: "missing refresh token",
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker}
|
||||
|
||||
updated, err := svc.ClearAccountError(context.Background(), 31)
|
||||
require.NoError(t, err)
|
||||
@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
||||
require.Nil(t, updated.RateLimitResetAt)
|
||||
require.Nil(t, updated.TempUnschedulableUntil)
|
||||
require.Empty(t, updated.TempUnschedulableReason)
|
||||
require.Equal(t, []int64{31}, blocker.clearedIDs)
|
||||
}
|
||||
|
||||
@ -3,13 +3,17 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
// 默认等待队列额外槽位
|
||||
defaultExtraWaitSlots = 20
|
||||
|
||||
defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond
|
||||
accountLoadBatchFetchTimeout = 3 * time.Second
|
||||
maxAccountLoadBatchCacheEntries = 256
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
// ConcurrencyService 管理账号和用户的并发限制。
|
||||
type ConcurrencyService struct {
|
||||
cache ConcurrencyCache
|
||||
|
||||
accountLoadCacheTTL atomic.Int64
|
||||
accountLoadCacheMu sync.RWMutex
|
||||
accountLoadCache map[string]cachedAccountLoadBatch
|
||||
accountLoadGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
type cachedAccountLoadBatch struct {
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewConcurrencyService 创建并发控制服务。
|
||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
svc := &ConcurrencyService{
|
||||
cache: cache,
|
||||
accountLoadCache: make(map[string]cachedAccountLoadBatch),
|
||||
}
|
||||
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
|
||||
return svc
|
||||
}
|
||||
|
||||
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
|
||||
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.accountLoadCacheTTL.Store(int64(ttl))
|
||||
if ttl <= 0 {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int {
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
// GetAccountsLoadBatch 批量获取账号负载信息。
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return s.getAccountsLoadBatch(ctx, accounts, true)
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return s.getAccountsLoadBatch(ctx, accounts, false)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
|
||||
ttl := time.Duration(s.accountLoadCacheTTL.Load())
|
||||
if !allowCache || ttl <= 0 {
|
||||
return s.fetchAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
key := accountLoadBatchCacheKey(accounts)
|
||||
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
|
||||
now := time.Now()
|
||||
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
|
||||
return cached, nil
|
||||
}
|
||||
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
|
||||
if fetchErr != nil {
|
||||
return nil, fetchErr
|
||||
}
|
||||
cached := cloneAccountLoadMap(loadMap)
|
||||
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
|
||||
return cached, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loadMap, _ := value.(map[int64]*AccountLoadInfo)
|
||||
if loadMap == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
baseCtx := context.Background()
|
||||
if ctx != nil {
|
||||
baseCtx = context.WithoutCancel(ctx)
|
||||
}
|
||||
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
|
||||
defer cancel()
|
||||
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
|
||||
s.accountLoadCacheMu.RLock()
|
||||
cached, ok := s.accountLoadCache[key]
|
||||
s.accountLoadCacheMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if !now.Before(cached.expiresAt) {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
|
||||
delete(s.accountLoadCache, key)
|
||||
}
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
return cached.loadMap, true
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
if s.accountLoadCache == nil {
|
||||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||
}
|
||||
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||
now := time.Now()
|
||||
for cacheKey, cached := range s.accountLoadCache {
|
||||
if !now.Before(cached.expiresAt) {
|
||||
delete(s.accountLoadCache, cacheKey)
|
||||
}
|
||||
}
|
||||
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||
for cacheKey := range s.accountLoadCache {
|
||||
delete(s.accountLoadCache, cacheKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.accountLoadCache[key] = cachedAccountLoadBatch{
|
||||
loadMap: loadMap,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
|
||||
hash := sha256.New()
|
||||
var buf [16]byte
|
||||
for _, account := range accounts {
|
||||
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
|
||||
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
|
||||
_, _ = hash.Write(buf[:])
|
||||
}
|
||||
sum := hash.Sum(nil)
|
||||
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
|
||||
}
|
||||
|
||||
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
|
||||
if len(loadMap) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
|
||||
for accountID, loadInfo := range loadMap {
|
||||
if loadInfo == nil {
|
||||
clone[accountID] = nil
|
||||
continue
|
||||
}
|
||||
copied := *loadInfo
|
||||
clone[accountID] = &copied
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// GetUsersLoadBatch returns load info for multiple users.
|
||||
|
||||
@ -7,7 +7,9 @@ import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct {
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
releasedRequestIDs []string
|
||||
loadBatchCalls atomic.Int64
|
||||
}
|
||||
|
||||
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
||||
@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
c.loadBatchCalls.Add(1)
|
||||
return c.loadBatch, c.loadBatchErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{
|
||||
loadBatch: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||
},
|
||||
}
|
||||
svc := NewConcurrencyService(cache)
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||
|
||||
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||
first, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, first[int64(1)].CurrentConcurrency)
|
||||
|
||||
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||
second, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, second[int64(1)].CurrentConcurrency)
|
||||
require.Equal(t, int64(1), cache.loadBatchCalls.Load())
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{
|
||||
loadBatch: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||
},
|
||||
}
|
||||
svc := NewConcurrencyService(cache)
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||
|
||||
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||
_, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||
fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency)
|
||||
require.Equal(t, int64(2), cache.loadBatchCalls.Load())
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountStateUpdateTimeout = 5 * time.Second
|
||||
openAIOAuth429FallbackCooldown = 5 * time.Second
|
||||
openAIStopSchedulingBridgeCooldown = 2 * time.Minute
|
||||
openAIOAuth429StormWindow = 10 * time.Second
|
||||
openAIOAuth429StormThreshold = 20
|
||||
openAIOAuth429StormMaxAccountSwitches = 1
|
||||
)
|
||||
|
||||
func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, openAIAccountStateUpdateTimeout)
|
||||
}
|
||||
|
||||
func isOpenAIOAuthAccount(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func isOpenAIAccount(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
|
||||
stateCtx, cancel := openAIAccountStateContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody)
|
||||
}
|
||||
if s == nil || account == nil || s.rateLimitService == nil {
|
||||
return false
|
||||
}
|
||||
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
|
||||
if shouldDisable {
|
||||
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
if s == nil || !isOpenAIOAuthAccount(account) {
|
||||
return
|
||||
}
|
||||
s.recordOpenAIOAuth429()
|
||||
|
||||
cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown)
|
||||
if s.rateLimitService != nil {
|
||||
if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) {
|
||||
cooldownUntil = *resetAt
|
||||
} else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil {
|
||||
if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) {
|
||||
cooldownUntil = resetAt
|
||||
}
|
||||
} else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 {
|
||||
cooldownUntil = time.Now().Add(cooldown)
|
||||
}
|
||||
}
|
||||
s.BlockAccountScheduling(account, cooldownUntil, "429")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||
if s == nil || !isOpenAIAccount(account) {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
blockUntil := until
|
||||
if blockUntil.IsZero() || !blockUntil.After(now) {
|
||||
blockUntil = now.Add(openAIStopSchedulingBridgeCooldown)
|
||||
}
|
||||
|
||||
for {
|
||||
current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
if !loaded {
|
||||
actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil)
|
||||
if !stored {
|
||||
return
|
||||
}
|
||||
current = actual
|
||||
}
|
||||
|
||||
currentUntil, ok := current.(time.Time)
|
||||
if !ok || currentUntil.IsZero() {
|
||||
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if currentUntil.After(blockUntil) {
|
||||
return
|
||||
}
|
||||
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(accountID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool {
|
||||
if s == nil || !isOpenAIAccount(account) {
|
||||
return false
|
||||
}
|
||||
value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
cooldownUntil, ok := value.(time.Time)
|
||||
if !ok || cooldownUntil.IsZero() {
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||
return false
|
||||
}
|
||||
if time.Now().Before(cooldownUntil) {
|
||||
return true
|
||||
}
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) recordOpenAIOAuth429() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||
if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||
if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) {
|
||||
s.openaiOAuth429WindowCount.Store(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.openaiOAuth429WindowCount.Add(1)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||
if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||
return false
|
||||
}
|
||||
return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool {
|
||||
if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches {
|
||||
return false
|
||||
}
|
||||
if !isOpenAIOAuthAccount(account) {
|
||||
return false
|
||||
}
|
||||
return s.isOpenAIOAuth429Storm()
|
||||
}
|
||||
@ -0,0 +1,101 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil)
|
||||
apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil)
|
||||
|
||||
require.False(t, shouldDisable)
|
||||
require.False(t, apiKeyShouldDisable)
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) {
|
||||
gateway := &OpenAIGatewayService{}
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
rateLimitService.SetAccountRuntimeBlocker(gateway)
|
||||
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
|
||||
shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
longUntil := time.Now().Add(10 * time.Minute)
|
||||
|
||||
svc.BlockAccountScheduling(account, longUntil, "oauth_401")
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||
|
||||
value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
require.True(t, ok)
|
||||
actualUntil, ok := value.(time.Time)
|
||||
require.True(t, ok)
|
||||
require.WithinDuration(t, longUntil, actualUntil, time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429")
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
|
||||
svc.ClearAccountSchedulingBlock(account.ID)
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||
|
||||
for i := 0; i < openAIOAuth429StormThreshold; i++ {
|
||||
svc.recordOpenAIOAuth429()
|
||||
}
|
||||
|
||||
require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0))
|
||||
}
|
||||
@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct {
|
||||
loadSkewMilliTotal atomic.Int64
|
||||
}
|
||||
|
||||
type openAIAccountLoadPlan struct {
|
||||
allCandidates []openAIAccountCandidateScore
|
||||
candidates []openAIAccountCandidateScore
|
||||
staleSnapshotCompactRetry []openAIAccountCandidateScore
|
||||
selectionOrder []openAIAccountCandidateScore
|
||||
candidateCount int
|
||||
topK int
|
||||
loadSkew float64
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||
if m == nil {
|
||||
return
|
||||
@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
}
|
||||
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
if acquireErr == nil && result != nil && result.Acquired {
|
||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@ -586,6 +596,231 @@ func buildOpenAIWeightedSelectionOrder(
|
||||
return order
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan(
|
||||
req OpenAIAccountScheduleRequest,
|
||||
filtered []*Account,
|
||||
loadMap map[int64]*AccountLoadInfo,
|
||||
) openAIAccountLoadPlan {
|
||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
errorRate, ttft, hasTTFT := 0.0, 0.0, false
|
||||
if s.stats != nil {
|
||||
errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID)
|
||||
}
|
||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
}
|
||||
|
||||
candidates := allCandidates
|
||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
for _, candidate := range allCandidates {
|
||||
if openAICompactSupportTier(candidate.account) == 0 {
|
||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
plan := openAIAccountLoadPlan{
|
||||
allCandidates: allCandidates,
|
||||
candidates: candidates,
|
||||
staleSnapshotCompactRetry: staleSnapshotCompactRetry,
|
||||
candidateCount: len(candidates),
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
for _, candidate := range candidates {
|
||||
if candidate.account.Priority < minPriority {
|
||||
minPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.account.Priority > maxPriority {
|
||||
maxPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = candidate.loadInfo.WaitingCount
|
||||
}
|
||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if candidate.ttft < minTTFT {
|
||||
minTTFT = candidate.ttft
|
||||
}
|
||||
if candidate.ttft > maxTTFT {
|
||||
maxTTFT = candidate.ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
}
|
||||
plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
plan.candidates = candidates
|
||||
|
||||
plan.topK = s.service.openAIWSLBTopK()
|
||||
if plan.topK > len(candidates) {
|
||||
plan.topK = len(candidates)
|
||||
}
|
||||
if plan.topK <= 0 {
|
||||
plan.topK = 1
|
||||
}
|
||||
|
||||
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder(
|
||||
req OpenAIAccountScheduleRequest,
|
||||
plan openAIAccountLoadPlan,
|
||||
) []openAIAccountCandidateScore {
|
||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 || plan.topK <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupTopK := plan.topK
|
||||
if groupTopK > len(pool) {
|
||||
groupTopK = len(pool)
|
||||
}
|
||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||
}
|
||||
|
||||
if req.RequireCompact {
|
||||
supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||
unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||
for _, candidate := range plan.candidates {
|
||||
switch openAICompactSupportTier(candidate.account) {
|
||||
case 2:
|
||||
supported = append(supported, candidate)
|
||||
case 1:
|
||||
unknown = append(unknown, candidate)
|
||||
}
|
||||
}
|
||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates))
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||
if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||
selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...)
|
||||
}
|
||||
return selectionOrder
|
||||
}
|
||||
|
||||
return buildSelectionOrder(plan.candidates)
|
||||
}
|
||||
|
||||
func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 {
|
||||
return nil
|
||||
}
|
||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
a, b := ordered[i], ordered[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
selectionOrder []openAIAccountCandidateScore,
|
||||
) (*AccountSelectionResult, bool, error) {
|
||||
compactBlocked := false
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, compactBlocked, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, compactBlocked, nil
|
||||
}
|
||||
}
|
||||
return nil, compactBlocked, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
@ -616,8 +851,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
if s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
|
||||
s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set")
|
||||
_ = s.service.accountRepo.SetError(ctx, account.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
@ -645,208 +884,46 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}
|
||||
}
|
||||
|
||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap)
|
||||
candidateCount := plan.candidateCount
|
||||
topK := plan.topK
|
||||
loadSkew := plan.loadSkew
|
||||
selectionOrder := plan.selectionOrder
|
||||
if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 {
|
||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
|
||||
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
|
||||
// 时作为最后兜底(snapshot 可能已陈旧)。
|
||||
candidates := allCandidates
|
||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
for _, candidate := range allCandidates {
|
||||
if openAICompactSupportTier(candidate.account) == 0 {
|
||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
|
||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
}
|
||||
|
||||
candidateCount := len(candidates)
|
||||
loadSkew := 0.0
|
||||
if len(candidates) > 0 {
|
||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
for _, candidate := range candidates {
|
||||
if candidate.account.Priority < minPriority {
|
||||
minPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.account.Priority > maxPriority {
|
||||
maxPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = candidate.loadInfo.WaitingCount
|
||||
}
|
||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if candidate.ttft < minTTFT {
|
||||
minTTFT = candidate.ttft
|
||||
}
|
||||
if candidate.ttft > maxTTFT {
|
||||
maxTTFT = candidate.ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
}
|
||||
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
}
|
||||
|
||||
topK := 0
|
||||
if len(candidates) > 0 {
|
||||
topK = s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
}
|
||||
|
||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 || topK <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupTopK := topK
|
||||
if groupTopK > len(pool) {
|
||||
groupTopK = len(pool)
|
||||
}
|
||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||
}
|
||||
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 {
|
||||
return nil
|
||||
}
|
||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
a, b := ordered[i], ordered[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
switch openAICompactSupportTier(candidate.account) {
|
||||
case 2:
|
||||
supported = append(supported, candidate)
|
||||
case 1:
|
||||
unknown = append(unknown, candidate)
|
||||
}
|
||||
}
|
||||
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
|
||||
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
|
||||
}
|
||||
} else {
|
||||
selectionOrder = buildSelectionOrder(candidates)
|
||||
if req.RequireCompact && len(selectionOrder) == 0 && s.service.schedulerSnapshot == nil {
|
||||
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
if len(selectionOrder) == 0 {
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0)
|
||||
}
|
||||
|
||||
compactBlocked := false
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder)
|
||||
if acquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil {
|
||||
return result, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
|
||||
if s.service.concurrencyService != nil {
|
||||
if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil {
|
||||
freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap)
|
||||
if len(freshPlan.selectionOrder) > 0 {
|
||||
freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder)
|
||||
if freshAcquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, freshAcquireErr
|
||||
}
|
||||
if freshResult != nil {
|
||||
return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil
|
||||
}
|
||||
compactBlocked = compactBlocked || freshCompactBlocked
|
||||
selectionOrder = freshPlan.selectionOrder
|
||||
candidateCount = freshPlan.candidateCount
|
||||
topK = freshPlan.topK
|
||||
loadSkew = freshPlan.loadSkew
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -893,6 +970,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -354,6 +354,9 @@ type OpenAIGatewayService struct {
|
||||
openaiAccountStats *openAIAccountRuntimeStats
|
||||
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiOAuth429WindowStartUnixNano atomic.Int64
|
||||
openaiOAuth429WindowCount atomic.Int64
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
codexSnapshotThrottle *accountWriteThrottle
|
||||
@ -417,6 +420,12 @@ func NewOpenAIGatewayService(
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
if rateLimitService != nil {
|
||||
rateLimitService.SetAccountRuntimeBlocker(svc)
|
||||
}
|
||||
if openAITokenProvider != nil {
|
||||
openAITokenProvider.SetAccountRuntimeBlocker(svc)
|
||||
}
|
||||
svc.logOpenAIWSModeBootstrap()
|
||||
return svc
|
||||
}
|
||||
@ -1381,13 +1390,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
||||
}
|
||||
|
||||
hydrated, err := s.hydrateSelectedAccount(ctx, selected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return s.hydrateSelectedAccount(ctx, selected)
|
||||
return hydrated, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
@ -1430,6 +1444,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
return nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
@ -1575,8 +1593,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
@ -1627,13 +1645,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, selectErr
|
||||
}
|
||||
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
return selection, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@ -1665,6 +1689,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(acc) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
@ -1687,6 +1714,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
})
|
||||
}
|
||||
|
||||
tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) {
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
selectionOrder := make([]accountWithLoad, 0, len(available))
|
||||
if requireCompact {
|
||||
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
||||
for _, item := range available {
|
||||
if openAICompactSupportTier(item.account) == tier {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
selectionOrder = appendTier(selectionOrder, 2)
|
||||
selectionOrder = appendTier(selectionOrder, 1)
|
||||
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
||||
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
||||
selectionOrder = appendTier(selectionOrder, 0)
|
||||
} else {
|
||||
selectionOrder = append(selectionOrder, available...)
|
||||
}
|
||||
|
||||
for _, item := range selectionOrder {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, true, selectErr
|
||||
}
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return selection, true, nil
|
||||
}
|
||||
}
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
@ -1707,87 +1820,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, selectErr
|
||||
}
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
selectionOrder := make([]accountWithLoad, 0, len(available))
|
||||
if requireCompact {
|
||||
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
||||
for _, item := range available {
|
||||
if openAICompactSupportTier(item.account) == tier {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
selectionOrder = appendTier(selectionOrder, 2)
|
||||
selectionOrder = appendTier(selectionOrder, 1)
|
||||
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
||||
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
||||
selectionOrder = appendTier(selectionOrder, 0)
|
||||
} else {
|
||||
selectionOrder = append(selectionOrder, available...)
|
||||
}
|
||||
|
||||
for _, item := range selectionOrder {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
||||
if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil {
|
||||
return nil, selectErr
|
||||
} else if selection != nil {
|
||||
return selection, nil
|
||||
} else if attempted {
|
||||
if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil {
|
||||
if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil {
|
||||
return nil, selectErr
|
||||
} else if selection != nil {
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1868,6 +1922,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
@ -1889,6 +1946,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
||||
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||
return nil
|
||||
}
|
||||
return latest
|
||||
}
|
||||
|
||||
@ -1935,6 +1995,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) {
|
||||
selection, err := s.newSelectionResult(ctx, account, true, release, nil)
|
||||
if err != nil && release != nil {
|
||||
release()
|
||||
}
|
||||
return selection, err
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
@ -1996,7 +2064,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
@ -3278,9 +3346,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
if s.rateLimitService != nil {
|
||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -3321,12 +3387,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
if s.rateLimitService != nil {
|
||||
// Passthrough mode preserves the raw upstream error response, but runtime
|
||||
// account state still needs to be updated so sticky routing can stop
|
||||
// reusing a freshly rate-limited account.
|
||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
|
||||
// 避免粘性路由继续复用刚被限流的账号。
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -4075,10 +4138,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
@ -4210,12 +4270,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
}
|
||||
|
||||
// Track rate limits and decide whether to trigger secondary failover.
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
)
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
|
||||
@ -80,6 +80,7 @@ type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
metrics *openAITokenRuntimeMetricsStore
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
p.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||
if p == nil {
|
||||
return OpenAITokenRuntimeMetrics{}
|
||||
@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account
|
||||
if p == nil || p.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if p.runtimeBlocker != nil {
|
||||
p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token")
|
||||
}
|
||||
bgCtx := context.Background()
|
||||
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
|
||||
slog.Warn("openai_token_provider.set_error_failed",
|
||||
|
||||
@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
||||
cache.getErr = errors.New("simulated cache miss")
|
||||
|
||||
provider := NewOpenAITokenProvider(repo, cache, nil)
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
provider.SetAccountRuntimeBlocker(blocker)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
||||
|
||||
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
|
||||
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
|
||||
require.Len(t, blocker.accounts, 1)
|
||||
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||
require.Equal(t, "missing_refresh_token", blocker.reasons[0])
|
||||
}
|
||||
|
||||
@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex
|
||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
}
|
||||
|
||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||
|
||||
@ -28,10 +28,16 @@ type RateLimitService struct {
|
||||
openAI403CounterCache OpenAI403CounterCache
|
||||
settingService *SettingService
|
||||
tokenCacheInvalidator TokenCacheInvalidator
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type AccountRuntimeBlocker interface {
|
||||
BlockAccountScheduling(account *Account, until time.Time, reason string)
|
||||
ClearAccountSchedulingBlock(accountID int64)
|
||||
}
|
||||
|
||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||
type SuccessfulTestRecoveryResult struct {
|
||||
ClearedError bool
|
||||
@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
s.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||
}
|
||||
|
||||
// ErrorPolicyResult 表示错误策略检查的结果
|
||||
type ErrorPolicyResult int
|
||||
|
||||
@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
cooldownMinutes = 10
|
||||
}
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "oauth_401")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
|
||||
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account
|
||||
|
||||
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
|
||||
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||
return
|
||||
@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
|
||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
s.notifyAccountSchedulingBlocked(account, *resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
|
||||
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||
s.notifyAccountSchedulingBlocked(account, result.resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 尝试解析 Gemini 格式(用于其他平台)
|
||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
s.notifyAccountSchedulingBlocked(account, resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun
|
||||
|
||||
resetAt := time.Now().Add(cooldown)
|
||||
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
|
||||
s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "529")
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
}
|
||||
}
|
||||
s.ResetOpenAI403Counter(ctx, accountID)
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
|
||||
}
|
||||
if result.ClearedError || result.ClearedRateLimit {
|
||||
s.ResetOpenAI403Counter(ctx, accountID)
|
||||
if result.ClearedError && !result.ClearedRateLimit {
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
|
||||
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
||||
reason = strings.TrimSpace(state.ErrorMessage)
|
||||
}
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
reason = state.ErrorMessage
|
||||
}
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
|
||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
|
||||
@ -6,16 +6,36 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type runtimeBlockRecorder struct {
|
||||
accounts []*Account
|
||||
until []time.Time
|
||||
reasons []string
|
||||
clearedIDs []int64
|
||||
}
|
||||
|
||||
func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||
r.accounts = append(r.accounts, account)
|
||||
r.until = append(r.until, until)
|
||||
r.reasons = append(r.reasons, reason)
|
||||
}
|
||||
|
||||
func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) {
|
||||
r.clearedIDs = append(r.clearedIDs, accountID)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
counter := &openAI403CounterCacheStub{counts: []int64{1}}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetOpenAI403CounterCache(counter)
|
||||
service.SetAccountRuntimeBlocker(blocker)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformOpenAI,
|
||||
@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
|
||||
require.Contains(t, repo.lastTempReason, "(1/3)")
|
||||
require.Len(t, blocker.accounts, 1)
|
||||
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||
require.Equal(t, "openai_403_temp", blocker.reasons[0])
|
||||
require.True(t, blocker.until[0].After(time.Now()))
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
|
||||
|
||||
@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
svc.SetAccountRuntimeBlocker(blocker)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||
require.Equal(t, []int64{42}, blocker.clearedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||
|
||||
@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) {
|
||||
cache := &snapshotHydrationCache{
|
||||
accounts: map[int64]*Account{},
|
||||
}
|
||||
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil)
|
||||
svc := &OpenAIGatewayService{
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
}
|
||||
releaseCalls := 0
|
||||
|
||||
selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() {
|
||||
releaseCalls++
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected hydration error")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection on hydration error")
|
||||
}
|
||||
if releaseCalls != 1 {
|
||||
t.Fatalf("expected release to be called once, got %d", releaseCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
|
||||
cache := &snapshotHydrationCache{
|
||||
snapshot: []*Account{
|
||||
|
||||
@ -27,6 +27,7 @@ type TokenRefreshService struct {
|
||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
|
||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
||||
s.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
s.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
|
||||
if isNonRetryableRefreshError(err) {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable")
|
||||
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
||||
slog.Error("token_refresh.set_error_status_failed",
|
||||
"account_id", account.ID,
|
||||
@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
// 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试)
|
||||
until := time.Now().Add(tokenRefreshTempUnschedDuration)
|
||||
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted")
|
||||
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
|
||||
slog.Warn("token_refresh.set_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
|
||||
@ -59,6 +59,7 @@ func ProvideTokenRefreshService(
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
proxyRepo ProxyRepository,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
runtimeBlocker AccountRuntimeBlocker,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
@ -67,6 +68,7 @@ func ProvideTokenRefreshService(
|
||||
svc.SetRefreshAPI(refreshAPI)
|
||||
// 调用侧显式注入后台刷新策略,避免策略漂移
|
||||
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
||||
svc.SetAccountRuntimeBlocker(runtimeBlocker)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@ -183,6 +185,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
||||
}
|
||||
if cfg != nil {
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond)
|
||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
return svc
|
||||
@ -455,6 +458,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
|
||||
@ -405,6 +405,9 @@ gateway:
|
||||
# Enable batch load calculation for scheduling
|
||||
# 启用调度批量负载计算
|
||||
load_batch_enabled: true
|
||||
# Tiny in-process TTL for batch load reads in milliseconds (0 disables)
|
||||
# 调度批量负载读取的进程内短缓存 TTL(毫秒,0 表示禁用)
|
||||
load_batch_cache_ttl_ms: 200
|
||||
# Slot cleanup interval (duration)
|
||||
# 并发槽位清理周期(时间段)
|
||||
slot_cleanup_interval: 30s
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user