fix: optimize OpenAI account cooldown scheduling
This commit is contained in:
parent
f59d9a5f8e
commit
1e406fed52
@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
privacyClientFactory := providePrivacyClientFactory()
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||||
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
rpmCache := repository.NewRPMCache(redisClient)
|
if err != nil {
|
||||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
return nil, err
|
||||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
}
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
|
||||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
|
||||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
|
||||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
|
||||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
|
||||||
driveClient := repository.NewGeminiDriveClient()
|
|
||||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
|
||||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
|
||||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
|
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||||
|
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||||
|
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||||
|
channelRepository := repository.NewChannelRepository(db)
|
||||||
|
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||||
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
|
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||||
|
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||||
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
|
||||||
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
|
rpmCache := repository.NewRPMCache(redisClient)
|
||||||
|
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||||
|
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||||
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
|
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||||
|
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||||
|
driveClient := repository.NewGeminiDriveClient()
|
||||||
|
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||||
|
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
|
||||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
|
||||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||||
@ -173,24 +188,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||||
promoHandler := admin.NewPromoHandler(promoService)
|
promoHandler := admin.NewPromoHandler(promoService)
|
||||||
opsRepository := repository.NewOpsRepository(db)
|
opsRepository := repository.NewOpsRepository(db)
|
||||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
|
||||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
|
||||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
channelRepository := repository.NewChannelRepository(db)
|
|
||||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
|
||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
|
||||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
|
||||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
@ -261,7 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
|
|||||||
@ -1009,7 +1009,8 @@ type GatewaySchedulingConfig struct {
|
|||||||
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||||
|
|
||||||
// 负载计算
|
// 负载计算
|
||||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||||
|
LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"`
|
||||||
// 快照桶读取时的 MGET 分块大小
|
// 快照桶读取时的 MGET 分块大小
|
||||||
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
|
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
|
||||||
// 快照重建时的缓存写入分块大小
|
// 快照重建时的缓存写入分块大小
|
||||||
@ -1828,6 +1829,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||||
|
viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200)
|
||||||
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
|
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
|
||||||
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
|
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
|
||||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||||
@ -2634,6 +2636,9 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
|
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
|
||||||
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
|
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative")
|
||||||
|
}
|
||||||
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
|
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
|
||||||
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
|
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
|||||||
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||||
t.Fatalf("LoadBatchEnabled = false, want true")
|
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||||
}
|
}
|
||||||
|
if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 {
|
||||||
|
t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS)
|
||||||
|
}
|
||||||
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||||
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||||
}
|
}
|
||||||
@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||||
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gateway scheduling load batch cache ttl",
|
||||||
|
mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 },
|
||||||
|
wantErr: "gateway.scheduling.load_batch_cache_ttl_ms",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "gateway scheduling outbox poll",
|
name: "gateway scheduling outbox poll",
|
||||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
||||||
|
|||||||
@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||||
}
|
}
|
||||||
writerSizeBeforeForward := c.Writer.Size()
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||||
|
defer func() {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||||
|
}()
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
responseLatencyMs := forwardDurationMs
|
responseLatencyMs := forwardDurationMs
|
||||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
|
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
|||||||
@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||||
}
|
}
|
||||||
writerSizeBeforeForward := c.Writer.Size()
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||||
|
defer func() {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||||
|
}()
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
responseLatencyMs := forwardDurationMs
|
responseLatencyMs := forwardDurationMs
|
||||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
|
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
reqLog.Warn("openai.upstream_failover_switching",
|
reqLog.Warn("openai.upstream_failover_switching",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
if channelMappingMsg.Mapped {
|
if channelMappingMsg.Mapped {
|
||||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
||||||
}
|
}
|
||||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||||
|
defer func() {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||||
|
}()
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
responseLatencyMs := forwardDurationMs
|
responseLatencyMs := forwardDurationMs
|
||||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
|
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||||
|
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
|||||||
@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
|||||||
|
|
||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||||
|
defer func() {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||||
|
}()
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
|
||||||
accountReleaseFunc()
|
|
||||||
}
|
|
||||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
responseLatencyMs := forwardDurationMs
|
responseLatencyMs := forwardDurationMs
|
||||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
@ -258,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
switchCount++
|
||||||
|
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
|||||||
@ -1258,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
|
|||||||
@ -531,6 +531,7 @@ type adminServiceImpl struct {
|
|||||||
defaultSubAssigner DefaultSubscriptionAssigner
|
defaultSubAssigner DefaultSubscriptionAssigner
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
privacyClientFactory PrivacyClientFactory
|
privacyClientFactory PrivacyClientFactory
|
||||||
|
runtimeBlocker AccountRuntimeBlocker
|
||||||
}
|
}
|
||||||
|
|
||||||
type userGroupRateBatchReader interface {
|
type userGroupRateBatchReader interface {
|
||||||
@ -556,6 +557,7 @@ func NewAdminService(
|
|||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
privacyClientFactory PrivacyClientFactory,
|
privacyClientFactory PrivacyClientFactory,
|
||||||
|
runtimeBlocker AccountRuntimeBlocker,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
@ -575,6 +577,7 @@ func NewAdminService(
|
|||||||
defaultSubAssigner: defaultSubAssigner,
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
privacyClientFactory: privacyClientFactory,
|
privacyClientFactory: privacyClientFactory,
|
||||||
|
runtimeBlocker: runtimeBlocker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
|||||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
|
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.runtimeBlocker != nil {
|
||||||
|
s.runtimeBlocker.ClearAccountSchedulingBlock(id)
|
||||||
|
}
|
||||||
return s.accountRepo.GetByID(ctx, id)
|
return s.accountRepo.GetByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
|||||||
TempUnschedulableReason: "missing refresh token",
|
TempUnschedulableReason: "missing refresh token",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := &adminServiceImpl{accountRepo: repo}
|
blocker := &runtimeBlockRecorder{}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker}
|
||||||
|
|
||||||
updated, err := svc.ClearAccountError(context.Background(), 31)
|
updated, err := svc.ClearAccountError(context.Background(), 31)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
|||||||
require.Nil(t, updated.RateLimitResetAt)
|
require.Nil(t, updated.RateLimitResetAt)
|
||||||
require.Nil(t, updated.TempUnschedulableUntil)
|
require.Nil(t, updated.TempUnschedulableUntil)
|
||||||
require.Empty(t, updated.TempUnschedulableReason)
|
require.Empty(t, updated.TempUnschedulableReason)
|
||||||
|
require.Equal(t, []int64{31}, blocker.clearedIDs)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,13 +3,17 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConcurrencyCache 定义并发控制的缓存接口
|
// ConcurrencyCache 定义并发控制的缓存接口
|
||||||
@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
	// defaultExtraWaitSlots is the number of extra wait-queue slots added on
	// top of a user's concurrency limit (see CalculateMaxWait).
	defaultExtraWaitSlots = 20

	// defaultAccountLoadBatchCacheTTL is the TTL NewConcurrencyService applies
	// to the in-process account-load batch cache.
	defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond

	// accountLoadBatchFetchTimeout bounds a single batch fetch issued against
	// the underlying ConcurrencyCache.
	accountLoadBatchFetchTimeout = 3 * time.Second

	// maxAccountLoadBatchCacheEntries caps how many batches are kept in the
	// in-process cache before entries are evicted.
	maxAccountLoadBatchCacheEntries = 256
)
|
||||||
|
|
||||||
// ConcurrencyService manages concurrent request limiting for accounts and users.
type ConcurrencyService struct {
	// cache is the backing ConcurrencyCache; the batch-load helpers return an
	// empty result when it is nil.
	cache ConcurrencyCache

	// accountLoadCacheTTL stores the batch-cache TTL as a time.Duration value;
	// a non-positive value disables caching.
	accountLoadCacheTTL atomic.Int64
	// accountLoadCacheMu guards accountLoadCache.
	accountLoadCacheMu sync.RWMutex
	// accountLoadCache maps a batch key (see accountLoadBatchCacheKey) to the
	// cached result for that batch.
	accountLoadCache map[string]cachedAccountLoadBatch
	// accountLoadGroup collapses concurrent fetches that share a batch key.
	accountLoadGroup singleflight.Group
}
|
||||||
|
|
||||||
// NewConcurrencyService creates a new ConcurrencyService
|
// cachedAccountLoadBatch is one entry of the account-load batch cache.
type cachedAccountLoadBatch struct {
	// loadMap is the cached per-account load information, keyed by account ID.
	loadMap map[int64]*AccountLoadInfo
	// expiresAt is the instant from which the entry is treated as expired.
	expiresAt time.Time
}
|
||||||
|
|
||||||
|
// NewConcurrencyService 创建并发控制服务。
|
||||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||||
return &ConcurrencyService{cache: cache}
|
svc := &ConcurrencyService{
|
||||||
|
cache: cache,
|
||||||
|
accountLoadCache: make(map[string]cachedAccountLoadBatch),
|
||||||
|
}
|
||||||
|
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
|
||||||
|
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.accountLoadCacheTTL.Store(int64(ttl))
|
||||||
|
if ttl <= 0 {
|
||||||
|
s.accountLoadCacheMu.Lock()
|
||||||
|
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||||
|
s.accountLoadCacheMu.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireResult represents the result of acquiring a concurrency slot
|
// AcquireResult represents the result of acquiring a concurrency slot
|
||||||
@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int {
|
|||||||
return userConcurrency + defaultExtraWaitSlots
|
return userConcurrency + defaultExtraWaitSlots
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
// GetAccountsLoadBatch 批量获取账号负载信息。
|
||||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
return s.getAccountsLoadBatch(ctx, accounts, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
|
||||||
|
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
return s.getAccountsLoadBatch(ctx, accounts, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
if len(accounts) == 0 {
|
||||||
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
if s.cache == nil {
|
if s.cache == nil {
|
||||||
return map[int64]*AccountLoadInfo{}, nil
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
}
|
}
|
||||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
|
||||||
|
ttl := time.Duration(s.accountLoadCacheTTL.Load())
|
||||||
|
if !allowCache || ttl <= 0 {
|
||||||
|
return s.fetchAccountsLoadBatch(ctx, accounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := accountLoadBatchCacheKey(accounts)
|
||||||
|
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
|
||||||
|
now := time.Now()
|
||||||
|
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
|
||||||
|
if fetchErr != nil {
|
||||||
|
return nil, fetchErr
|
||||||
|
}
|
||||||
|
cached := cloneAccountLoadMap(loadMap)
|
||||||
|
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
|
||||||
|
return cached, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
loadMap, _ := value.(map[int64]*AccountLoadInfo)
|
||||||
|
if loadMap == nil {
|
||||||
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
return loadMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return map[int64]*AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
baseCtx := context.Background()
|
||||||
|
if ctx != nil {
|
||||||
|
baseCtx = context.WithoutCancel(ctx)
|
||||||
|
}
|
||||||
|
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
|
||||||
|
s.accountLoadCacheMu.RLock()
|
||||||
|
cached, ok := s.accountLoadCache[key]
|
||||||
|
s.accountLoadCacheMu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if !now.Before(cached.expiresAt) {
|
||||||
|
s.accountLoadCacheMu.Lock()
|
||||||
|
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
|
||||||
|
delete(s.accountLoadCache, key)
|
||||||
|
}
|
||||||
|
s.accountLoadCacheMu.Unlock()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return cached.loadMap, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
|
||||||
|
s.accountLoadCacheMu.Lock()
|
||||||
|
if s.accountLoadCache == nil {
|
||||||
|
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||||
|
}
|
||||||
|
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||||
|
now := time.Now()
|
||||||
|
for cacheKey, cached := range s.accountLoadCache {
|
||||||
|
if !now.Before(cached.expiresAt) {
|
||||||
|
delete(s.accountLoadCache, cacheKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||||
|
for cacheKey := range s.accountLoadCache {
|
||||||
|
delete(s.accountLoadCache, cacheKey)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.accountLoadCache[key] = cachedAccountLoadBatch{
|
||||||
|
loadMap: loadMap,
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
s.accountLoadCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
|
||||||
|
hash := sha256.New()
|
||||||
|
var buf [16]byte
|
||||||
|
for _, account := range accounts {
|
||||||
|
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
|
||||||
|
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
|
||||||
|
_, _ = hash.Write(buf[:])
|
||||||
|
}
|
||||||
|
sum := hash.Sum(nil)
|
||||||
|
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
|
||||||
|
if len(loadMap) == 0 {
|
||||||
|
return map[int64]*AccountLoadInfo{}
|
||||||
|
}
|
||||||
|
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
|
||||||
|
for accountID, loadInfo := range loadMap {
|
||||||
|
if loadInfo == nil {
|
||||||
|
clone[accountID] = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copied := *loadInfo
|
||||||
|
clone[accountID] = &copied
|
||||||
|
}
|
||||||
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersLoadBatch returns load info for multiple users.
|
// GetUsersLoadBatch returns load info for multiple users.
|
||||||
|
|||||||
@ -7,7 +7,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct {
|
|||||||
// 记录调用
|
// 记录调用
|
||||||
releasedAccountIDs []int64
|
releasedAccountIDs []int64
|
||||||
releasedRequestIDs []string
|
releasedRequestIDs []string
|
||||||
|
loadBatchCalls atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
||||||
@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// GetAccountsLoadBatch counts the call in loadBatchCalls and returns the
// stub's preset loadBatch and loadBatchErr; both arguments are ignored.
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	c.loadBatchCalls.Add(1)
	return c.loadBatch, c.loadBatchErr
}
|
||||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||||
@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
|
|||||||
require.Empty(t, result)
|
require.Empty(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) {
|
||||||
|
cache := &stubConcurrencyCacheForTest{
|
||||||
|
loadBatch: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewConcurrencyService(cache)
|
||||||
|
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||||
|
|
||||||
|
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||||
|
first, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, first[int64(1)].CurrentConcurrency)
|
||||||
|
|
||||||
|
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||||
|
second, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, second[int64(1)].CurrentConcurrency)
|
||||||
|
require.Equal(t, int64(1), cache.loadBatchCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) {
|
||||||
|
cache := &stubConcurrencyCacheForTest{
|
||||||
|
loadBatch: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewConcurrencyService(cache)
|
||||||
|
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||||
|
|
||||||
|
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||||
|
_, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||||
|
fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency)
|
||||||
|
require.Equal(t, int64(2), cache.loadBatchCalls.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestIncrementWaitCount_Success(t *testing.T) {
|
func TestIncrementWaitCount_Success(t *testing.T) {
|
||||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||||
svc := NewConcurrencyService(cache)
|
svc := NewConcurrencyService(cache)
|
||||||
|
|||||||
@ -0,0 +1,169 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// openAIAccountStateUpdateTimeout bounds the detached context produced by
	// openAIAccountStateContext.
	openAIAccountStateUpdateTimeout = 5 * time.Second

	// openAIOAuth429FallbackCooldown is the runtime block applied after an
	// OAuth 429 when no usable reset time or configured cooldown is found.
	openAIOAuth429FallbackCooldown = 5 * time.Second

	// openAIStopSchedulingBridgeCooldown is the block duration used by
	// BlockAccountScheduling when the requested deadline is zero or not in the
	// future.
	openAIStopSchedulingBridgeCooldown = 2 * time.Minute

	// openAIOAuth429StormWindow is the length of the window over which OAuth
	// 429 responses are counted.
	openAIOAuth429StormWindow = 10 * time.Second

	// openAIOAuth429StormThreshold is the in-window 429 count at or above which
	// a storm is reported.
	openAIOAuth429StormThreshold = 20

	// openAIOAuth429StormMaxAccountSwitches is the number of failed account
	// switches from which failover may be stopped during a storm.
	openAIOAuth429StormMaxAccountSwitches = 1
)
|
||||||
|
|
||||||
|
func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
base := context.Background()
|
||||||
|
if ctx != nil {
|
||||||
|
base = context.WithoutCancel(ctx)
|
||||||
|
}
|
||||||
|
return context.WithTimeout(base, openAIAccountStateUpdateTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAIOAuthAccount(account *Account) bool {
|
||||||
|
return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAIAccount(account *Account) bool {
|
||||||
|
return account != nil && account.Platform == PlatformOpenAI
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
|
||||||
|
stateCtx, cancel := openAIAccountStateContext(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody)
|
||||||
|
}
|
||||||
|
if s == nil || account == nil || s.rateLimitService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
|
||||||
|
if shouldDisable {
|
||||||
|
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||||
|
}
|
||||||
|
return shouldDisable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||||
|
if s == nil || !isOpenAIOAuthAccount(account) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.recordOpenAIOAuth429()
|
||||||
|
|
||||||
|
cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown)
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) {
|
||||||
|
cooldownUntil = *resetAt
|
||||||
|
} else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil {
|
||||||
|
if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) {
|
||||||
|
cooldownUntil = resetAt
|
||||||
|
}
|
||||||
|
} else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 {
|
||||||
|
cooldownUntil = time.Now().Add(cooldown)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.BlockAccountScheduling(account, cooldownUntil, "429")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||||
|
if s == nil || !isOpenAIAccount(account) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
blockUntil := until
|
||||||
|
if blockUntil.IsZero() || !blockUntil.After(now) {
|
||||||
|
blockUntil = now.Add(openAIStopSchedulingBridgeCooldown)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||||
|
if !loaded {
|
||||||
|
actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil)
|
||||||
|
if !stored {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
current = actual
|
||||||
|
}
|
||||||
|
|
||||||
|
currentUntil, ok := current.(time.Time)
|
||||||
|
if !ok || currentUntil.IsZero() {
|
||||||
|
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if currentUntil.After(blockUntil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) {
|
||||||
|
if s == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.openaiAccountRuntimeBlockUntil.Delete(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool {
|
||||||
|
if s == nil || !isOpenAIAccount(account) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cooldownUntil, ok := value.(time.Time)
|
||||||
|
if !ok || cooldownUntil.IsZero() {
|
||||||
|
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if time.Now().Before(cooldownUntil) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) recordOpenAIOAuth429() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||||
|
if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||||
|
if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) {
|
||||||
|
s.openaiOAuth429WindowCount.Store(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.openaiOAuth429WindowCount.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||||
|
if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool {
|
||||||
|
if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !isOpenAIOAuthAccount(account) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.isOpenAIOAuth429Storm()
|
||||||
|
}
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||||
|
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||||
|
|
||||||
|
shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil)
|
||||||
|
apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil)
|
||||||
|
|
||||||
|
require.False(t, shouldDisable)
|
||||||
|
require.False(t, apiKeyShouldDisable)
|
||||||
|
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||||
|
|
||||||
|
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||||
|
|
||||||
|
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
|
||||||
|
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||||
|
|
||||||
|
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) {
|
||||||
|
gateway := &OpenAIGatewayService{}
|
||||||
|
repo := &rateLimitAccountRepoStub{}
|
||||||
|
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
rateLimitService.SetAccountRuntimeBlocker(gateway)
|
||||||
|
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||||
|
|
||||||
|
shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden"))
|
||||||
|
|
||||||
|
require.True(t, shouldDisable)
|
||||||
|
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||||
|
longUntil := time.Now().Add(10 * time.Minute)
|
||||||
|
|
||||||
|
svc.BlockAccountScheduling(account, longUntil, "oauth_401")
|
||||||
|
svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||||
|
|
||||||
|
value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||||
|
require.True(t, ok)
|
||||||
|
actualUntil, ok := value.(time.Time)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.WithinDuration(t, longUntil, actualUntil, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||||
|
|
||||||
|
svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429")
|
||||||
|
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
|
||||||
|
svc.ClearAccountSchedulingBlock(account.ID)
|
||||||
|
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||||
|
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||||
|
|
||||||
|
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||||
|
|
||||||
|
for i := 0; i < openAIOAuth429StormThreshold; i++ {
|
||||||
|
svc.recordOpenAIOAuth429()
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||||
|
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1))
|
||||||
|
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1))
|
||||||
|
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0))
|
||||||
|
}
|
||||||
@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct {
|
|||||||
loadSkewMilliTotal atomic.Int64
|
loadSkewMilliTotal atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openAIAccountLoadPlan bundles the candidate sets produced by
// buildOpenAIAccountLoadPlan for one scheduling request.
type openAIAccountLoadPlan struct {
	// allCandidates holds one scored entry per filtered account.
	allCandidates []openAIAccountCandidateScore
	// candidates is allCandidates, or, when the request requires compact
	// support, only the entries whose compact support tier is non-zero.
	candidates []openAIAccountCandidateScore
	// staleSnapshotCompactRetry holds the entries excluded from candidates
	// because their compact support tier is zero.
	staleSnapshotCompactRetry []openAIAccountCandidateScore
	// selectionOrder is the order produced by buildOpenAISelectionOrder.
	selectionOrder []openAIAccountCandidateScore
	// candidateCount is len(candidates).
	candidateCount int
	// topK — NOTE(review): presumably the size of the top-K selection pool;
	// its assignment is not visible here, confirm against the builder.
	topK int
	// loadSkew — NOTE(review): presumably a load-imbalance measure across
	// candidates; its computation is not visible here, confirm.
	loadSkew float64
}
|
||||||
|
|
||||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return
|
return
|
||||||
@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if acquireErr == nil && result.Acquired {
|
if acquireErr == nil && result != nil && result.Acquired {
|
||||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: account,
|
Account: account,
|
||||||
@ -586,6 +596,231 @@ func buildOpenAIWeightedSelectionOrder(
|
|||||||
return order
|
return order
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan(
|
||||||
|
req OpenAIAccountScheduleRequest,
|
||||||
|
filtered []*Account,
|
||||||
|
loadMap map[int64]*AccountLoadInfo,
|
||||||
|
) openAIAccountLoadPlan {
|
||||||
|
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||||
|
for _, account := range filtered {
|
||||||
|
loadInfo := loadMap[account.ID]
|
||||||
|
if loadInfo == nil {
|
||||||
|
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||||
|
}
|
||||||
|
errorRate, ttft, hasTTFT := 0.0, 0.0, false
|
||||||
|
if s.stats != nil {
|
||||||
|
errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID)
|
||||||
|
}
|
||||||
|
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||||
|
account: account,
|
||||||
|
loadInfo: loadInfo,
|
||||||
|
errorRate: errorRate,
|
||||||
|
ttft: ttft,
|
||||||
|
hasTTFT: hasTTFT,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates := allCandidates
|
||||||
|
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||||
|
if req.RequireCompact {
|
||||||
|
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||||
|
for _, candidate := range allCandidates {
|
||||||
|
if openAICompactSupportTier(candidate.account) == 0 {
|
||||||
|
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, candidate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
plan := openAIAccountLoadPlan{
|
||||||
|
allCandidates: allCandidates,
|
||||||
|
candidates: candidates,
|
||||||
|
staleSnapshotCompactRetry: staleSnapshotCompactRetry,
|
||||||
|
candidateCount: len(candidates),
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||||
|
return plan
|
||||||
|
}
|
||||||
|
|
||||||
|
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||||
|
maxWaiting := 1
|
||||||
|
loadRateSum := 0.0
|
||||||
|
loadRateSumSquares := 0.0
|
||||||
|
minTTFT, maxTTFT := 0.0, 0.0
|
||||||
|
hasTTFTSample := false
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate.account.Priority < minPriority {
|
||||||
|
minPriority = candidate.account.Priority
|
||||||
|
}
|
||||||
|
if candidate.account.Priority > maxPriority {
|
||||||
|
maxPriority = candidate.account.Priority
|
||||||
|
}
|
||||||
|
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||||
|
maxWaiting = candidate.loadInfo.WaitingCount
|
||||||
|
}
|
||||||
|
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||||
|
if !hasTTFTSample {
|
||||||
|
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||||
|
hasTTFTSample = true
|
||||||
|
} else {
|
||||||
|
if candidate.ttft < minTTFT {
|
||||||
|
minTTFT = candidate.ttft
|
||||||
|
}
|
||||||
|
if candidate.ttft > maxTTFT {
|
||||||
|
maxTTFT = candidate.ttft
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||||
|
loadRateSum += loadRate
|
||||||
|
loadRateSumSquares += loadRate * loadRate
|
||||||
|
}
|
||||||
|
plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||||
|
|
||||||
|
weights := s.service.openAIWSSchedulerWeights()
|
||||||
|
for i := range candidates {
|
||||||
|
item := &candidates[i]
|
||||||
|
priorityFactor := 1.0
|
||||||
|
if maxPriority > minPriority {
|
||||||
|
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||||
|
}
|
||||||
|
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||||
|
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||||
|
errorFactor := 1 - clamp01(item.errorRate)
|
||||||
|
ttftFactor := 0.5
|
||||||
|
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||||
|
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||||
|
}
|
||||||
|
|
||||||
|
item.score = weights.Priority*priorityFactor +
|
||||||
|
weights.Load*loadFactor +
|
||||||
|
weights.Queue*queueFactor +
|
||||||
|
weights.ErrorRate*errorFactor +
|
||||||
|
weights.TTFT*ttftFactor
|
||||||
|
}
|
||||||
|
plan.candidates = candidates
|
||||||
|
|
||||||
|
plan.topK = s.service.openAIWSLBTopK()
|
||||||
|
if plan.topK > len(candidates) {
|
||||||
|
plan.topK = len(candidates)
|
||||||
|
}
|
||||||
|
if plan.topK <= 0 {
|
||||||
|
plan.topK = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||||
|
return plan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder(
|
||||||
|
req OpenAIAccountScheduleRequest,
|
||||||
|
plan openAIAccountLoadPlan,
|
||||||
|
) []openAIAccountCandidateScore {
|
||||||
|
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||||
|
if len(pool) == 0 || plan.topK <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
groupTopK := plan.topK
|
||||||
|
if groupTopK > len(pool) {
|
||||||
|
groupTopK = len(pool)
|
||||||
|
}
|
||||||
|
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||||
|
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.RequireCompact {
|
||||||
|
supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||||
|
unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||||
|
for _, candidate := range plan.candidates {
|
||||||
|
switch openAICompactSupportTier(candidate.account) {
|
||||||
|
case 2:
|
||||||
|
supported = append(supported, candidate)
|
||||||
|
case 1:
|
||||||
|
unknown = append(unknown, candidate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates))
|
||||||
|
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||||
|
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||||
|
if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||||
|
selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...)
|
||||||
|
}
|
||||||
|
return selectionOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildSelectionOrder(plan.candidates)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||||
|
if len(pool) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||||
|
sort.SliceStable(ordered, func(i, j int) bool {
|
||||||
|
a, b := ordered[i], ordered[j]
|
||||||
|
if a.account.Priority != b.account.Priority {
|
||||||
|
return a.account.Priority < b.account.Priority
|
||||||
|
}
|
||||||
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||||
|
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||||
|
}
|
||||||
|
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||||
|
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||||
|
return true
|
||||||
|
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
|
||||||
|
ctx context.Context,
|
||||||
|
req OpenAIAccountScheduleRequest,
|
||||||
|
selectionOrder []openAIAccountCandidateScore,
|
||||||
|
) (*AccountSelectionResult, bool, error) {
|
||||||
|
compactBlocked := false
|
||||||
|
for i := 0; i < len(selectionOrder); i++ {
|
||||||
|
candidate := selectionOrder[i]
|
||||||
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||||
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||||
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||||
|
compactBlocked = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
|
if acquireErr != nil {
|
||||||
|
return nil, compactBlocked, acquireErr
|
||||||
|
}
|
||||||
|
if result != nil && result.Acquired {
|
||||||
|
if req.SessionHash != "" {
|
||||||
|
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: fresh,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, compactBlocked, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, compactBlocked, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req OpenAIAccountScheduleRequest,
|
req OpenAIAccountScheduleRequest,
|
||||||
@ -616,8 +851,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
|
||||||
|
s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set")
|
||||||
_ = s.service.accountRepo.SetError(ctx, account.ID,
|
_ = s.service.accountRepo.SetError(ctx, account.ID,
|
||||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
continue
|
continue
|
||||||
@ -645,208 +884,46 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap)
|
||||||
for _, account := range filtered {
|
candidateCount := plan.candidateCount
|
||||||
loadInfo := loadMap[account.ID]
|
topK := plan.topK
|
||||||
if loadInfo == nil {
|
loadSkew := plan.loadSkew
|
||||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
selectionOrder := plan.selectionOrder
|
||||||
}
|
if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 {
|
||||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
|
||||||
account: account,
|
|
||||||
loadInfo: loadInfo,
|
|
||||||
errorRate: errorRate,
|
|
||||||
ttft: ttft,
|
|
||||||
hasTTFT: hasTTFT,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
if req.RequireCompact && len(selectionOrder) == 0 && s.service.schedulerSnapshot == nil {
|
||||||
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
|
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
||||||
// 时作为最后兜底(snapshot 可能已陈旧)。
|
|
||||||
candidates := allCandidates
|
|
||||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
|
||||||
if req.RequireCompact {
|
|
||||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
|
||||||
for _, candidate := range allCandidates {
|
|
||||||
if openAICompactSupportTier(candidate.account) == 0 {
|
|
||||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
candidates = append(candidates, candidate)
|
|
||||||
}
|
|
||||||
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
|
|
||||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
candidateCount := len(candidates)
|
|
||||||
loadSkew := 0.0
|
|
||||||
if len(candidates) > 0 {
|
|
||||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
|
||||||
maxWaiting := 1
|
|
||||||
loadRateSum := 0.0
|
|
||||||
loadRateSumSquares := 0.0
|
|
||||||
minTTFT, maxTTFT := 0.0, 0.0
|
|
||||||
hasTTFTSample := false
|
|
||||||
for _, candidate := range candidates {
|
|
||||||
if candidate.account.Priority < minPriority {
|
|
||||||
minPriority = candidate.account.Priority
|
|
||||||
}
|
|
||||||
if candidate.account.Priority > maxPriority {
|
|
||||||
maxPriority = candidate.account.Priority
|
|
||||||
}
|
|
||||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
|
||||||
maxWaiting = candidate.loadInfo.WaitingCount
|
|
||||||
}
|
|
||||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
|
||||||
if !hasTTFTSample {
|
|
||||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
|
||||||
hasTTFTSample = true
|
|
||||||
} else {
|
|
||||||
if candidate.ttft < minTTFT {
|
|
||||||
minTTFT = candidate.ttft
|
|
||||||
}
|
|
||||||
if candidate.ttft > maxTTFT {
|
|
||||||
maxTTFT = candidate.ttft
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
|
||||||
loadRateSum += loadRate
|
|
||||||
loadRateSumSquares += loadRate * loadRate
|
|
||||||
}
|
|
||||||
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
|
||||||
|
|
||||||
weights := s.service.openAIWSSchedulerWeights()
|
|
||||||
for i := range candidates {
|
|
||||||
item := &candidates[i]
|
|
||||||
priorityFactor := 1.0
|
|
||||||
if maxPriority > minPriority {
|
|
||||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
|
||||||
}
|
|
||||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
|
||||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
|
||||||
errorFactor := 1 - clamp01(item.errorRate)
|
|
||||||
ttftFactor := 0.5
|
|
||||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
|
||||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
|
||||||
}
|
|
||||||
|
|
||||||
item.score = weights.Priority*priorityFactor +
|
|
||||||
weights.Load*loadFactor +
|
|
||||||
weights.Queue*queueFactor +
|
|
||||||
weights.ErrorRate*errorFactor +
|
|
||||||
weights.TTFT*ttftFactor
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
topK := 0
|
|
||||||
if len(candidates) > 0 {
|
|
||||||
topK = s.service.openAIWSLBTopK()
|
|
||||||
if topK > len(candidates) {
|
|
||||||
topK = len(candidates)
|
|
||||||
}
|
|
||||||
if topK <= 0 {
|
|
||||||
topK = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
|
||||||
if len(pool) == 0 || topK <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
groupTopK := topK
|
|
||||||
if groupTopK > len(pool) {
|
|
||||||
groupTopK = len(pool)
|
|
||||||
}
|
|
||||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
|
||||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
|
||||||
}
|
|
||||||
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
|
||||||
if len(pool) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
|
||||||
sort.SliceStable(ordered, func(i, j int) bool {
|
|
||||||
a, b := ordered[i], ordered[j]
|
|
||||||
if a.account.Priority != b.account.Priority {
|
|
||||||
return a.account.Priority < b.account.Priority
|
|
||||||
}
|
|
||||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
|
||||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
|
||||||
}
|
|
||||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
|
||||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
|
||||||
return true
|
|
||||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
|
||||||
return false
|
|
||||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return ordered
|
|
||||||
}
|
|
||||||
|
|
||||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
|
||||||
if req.RequireCompact {
|
|
||||||
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
|
|
||||||
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
|
|
||||||
for _, candidate := range candidates {
|
|
||||||
switch openAICompactSupportTier(candidate.account) {
|
|
||||||
case 2:
|
|
||||||
supported = append(supported, candidate)
|
|
||||||
case 1:
|
|
||||||
unknown = append(unknown, candidate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
|
|
||||||
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
|
||||||
}
|
|
||||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
|
||||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
|
||||||
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
|
||||||
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
selectionOrder = buildSelectionOrder(candidates)
|
|
||||||
}
|
}
|
||||||
if len(selectionOrder) == 0 {
|
if len(selectionOrder) == 0 {
|
||||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
|
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
compactBlocked := false
|
result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder)
|
||||||
for i := 0; i < len(selectionOrder); i++ {
|
if acquireErr != nil {
|
||||||
candidate := selectionOrder[i]
|
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
}
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if result != nil {
|
||||||
continue
|
return result, candidateCount, topK, loadSkew, nil
|
||||||
}
|
}
|
||||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if s.service.concurrencyService != nil {
|
||||||
continue
|
if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil {
|
||||||
}
|
freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap)
|
||||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
if len(freshPlan.selectionOrder) > 0 {
|
||||||
compactBlocked = true
|
freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder)
|
||||||
continue
|
if freshAcquireErr != nil {
|
||||||
}
|
return nil, candidateCount, topK, loadSkew, freshAcquireErr
|
||||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
}
|
||||||
if acquireErr != nil {
|
if freshResult != nil {
|
||||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil
|
||||||
}
|
}
|
||||||
if result != nil && result.Acquired {
|
compactBlocked = compactBlocked || freshCompactBlocked
|
||||||
if req.SessionHash != "" {
|
selectionOrder = freshPlan.selectionOrder
|
||||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
candidateCount = freshPlan.candidateCount
|
||||||
|
topK = freshPlan.topK
|
||||||
|
loadSkew = freshPlan.loadSkew
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
|
||||||
Account: fresh,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, candidateCount, topK, loadSkew, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -893,6 +970,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: upstreamDetail,
|
Detail: upstreamDetail,
|
||||||
})
|
})
|
||||||
if s.rateLimitService != nil {
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
}
|
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
|
|||||||
@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: upstreamDetail,
|
Detail: upstreamDetail,
|
||||||
})
|
})
|
||||||
if s.rateLimitService != nil {
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
}
|
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
|
|||||||
@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: upstreamDetail,
|
Detail: upstreamDetail,
|
||||||
})
|
})
|
||||||
if s.rateLimitService != nil {
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
}
|
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
|
|||||||
@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
|||||||
Message: upstreamMsg,
|
Message: upstreamMsg,
|
||||||
Detail: upstreamDetail,
|
Detail: upstreamDetail,
|
||||||
})
|
})
|
||||||
if s.rateLimitService != nil {
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
}
|
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
ResponseBody: respBody,
|
ResponseBody: respBody,
|
||||||
|
|||||||
@ -354,6 +354,9 @@ type OpenAIGatewayService struct {
|
|||||||
openaiAccountStats *openAIAccountRuntimeStats
|
openaiAccountStats *openAIAccountRuntimeStats
|
||||||
|
|
||||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
|
openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time
|
||||||
|
openaiOAuth429WindowStartUnixNano atomic.Int64
|
||||||
|
openaiOAuth429WindowCount atomic.Int64
|
||||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||||
codexSnapshotThrottle *accountWriteThrottle
|
codexSnapshotThrottle *accountWriteThrottle
|
||||||
@ -417,6 +420,12 @@ func NewOpenAIGatewayService(
|
|||||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||||
}
|
}
|
||||||
|
if rateLimitService != nil {
|
||||||
|
rateLimitService.SetAccountRuntimeBlocker(svc)
|
||||||
|
}
|
||||||
|
if openAITokenProvider != nil {
|
||||||
|
openAITokenProvider.SetAccountRuntimeBlocker(svc)
|
||||||
|
}
|
||||||
svc.logOpenAIWSModeBootstrap()
|
svc.logOpenAIWSModeBootstrap()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@ -1381,13 +1390,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
|||||||
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hydrated, err := s.hydrateSelectedAccount(ctx, selected)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// 4. 设置粘性会话绑定
|
// 4. 设置粘性会话绑定
|
||||||
// Set sticky session binding
|
// Set sticky session binding
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.hydrateSelectedAccount(ctx, selected)
|
return hydrated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||||
@ -1430,6 +1444,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
|||||||
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||||
if account == nil {
|
if account == nil {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
@ -1575,8 +1593,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result != nil && result.Acquired {
|
||||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||||
}
|
}
|
||||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||||
@ -1627,13 +1645,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||||
if account == nil {
|
if account == nil {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
|
} else if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
|
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
} else {
|
} else {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result != nil && result.Acquired {
|
||||||
|
selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||||
|
if selectErr != nil {
|
||||||
|
return nil, selectErr
|
||||||
|
}
|
||||||
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
||||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
return selection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
@ -1665,6 +1689,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if !acc.IsSchedulable() {
|
if !acc.IsSchedulable() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if s.isOpenAIAccountRuntimeBlocked(acc) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1687,6 +1714,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) {
|
||||||
|
var available []accountWithLoad
|
||||||
|
for _, acc := range candidates {
|
||||||
|
loadInfo := loadMap[acc.ID]
|
||||||
|
if loadInfo == nil {
|
||||||
|
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||||
|
}
|
||||||
|
if loadInfo.LoadRate < 100 {
|
||||||
|
available = append(available, accountWithLoad{
|
||||||
|
account: acc,
|
||||||
|
loadInfo: loadInfo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(available) == 0 {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.SliceStable(available, func(i, j int) bool {
|
||||||
|
a, b := available[i], available[j]
|
||||||
|
if a.account.Priority != b.account.Priority {
|
||||||
|
return a.account.Priority < b.account.Priority
|
||||||
|
}
|
||||||
|
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||||
|
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||||
|
return true
|
||||||
|
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
shuffleWithinSortGroups(available)
|
||||||
|
|
||||||
|
selectionOrder := make([]accountWithLoad, 0, len(available))
|
||||||
|
if requireCompact {
|
||||||
|
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
||||||
|
for _, item := range available {
|
||||||
|
if openAICompactSupportTier(item.account) == tier {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
selectionOrder = appendTier(selectionOrder, 2)
|
||||||
|
selectionOrder = appendTier(selectionOrder, 1)
|
||||||
|
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
||||||
|
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
||||||
|
selectionOrder = appendTier(selectionOrder, 0)
|
||||||
|
} else {
|
||||||
|
selectionOrder = append(selectionOrder, available...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range selectionOrder {
|
||||||
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
||||||
|
if fresh == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
||||||
|
if fresh == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
|
if err == nil && result != nil && result.Acquired {
|
||||||
|
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||||
|
if selectErr != nil {
|
||||||
|
return nil, true, selectErr
|
||||||
|
}
|
||||||
|
if sessionHash != "" {
|
||||||
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||||
|
}
|
||||||
|
return selection, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ordered := append([]*Account(nil), candidates...)
|
ordered := append([]*Account(nil), candidates...)
|
||||||
@ -1707,87 +1820,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result != nil && result.Acquired {
|
||||||
|
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||||
|
if selectErr != nil {
|
||||||
|
return nil, selectErr
|
||||||
|
}
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
return selection, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var available []accountWithLoad
|
if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil {
|
||||||
for _, acc := range candidates {
|
return nil, selectErr
|
||||||
loadInfo := loadMap[acc.ID]
|
} else if selection != nil {
|
||||||
if loadInfo == nil {
|
return selection, nil
|
||||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
} else if attempted {
|
||||||
}
|
if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil {
|
||||||
if loadInfo.LoadRate < 100 {
|
if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil {
|
||||||
available = append(available, accountWithLoad{
|
return nil, selectErr
|
||||||
account: acc,
|
} else if selection != nil {
|
||||||
loadInfo: loadInfo,
|
return selection, nil
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(available) > 0 {
|
|
||||||
sort.SliceStable(available, func(i, j int) bool {
|
|
||||||
a, b := available[i], available[j]
|
|
||||||
if a.account.Priority != b.account.Priority {
|
|
||||||
return a.account.Priority < b.account.Priority
|
|
||||||
}
|
|
||||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
|
||||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
|
||||||
}
|
|
||||||
switch {
|
|
||||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
|
||||||
return true
|
|
||||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
|
||||||
return false
|
|
||||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
shuffleWithinSortGroups(available)
|
|
||||||
|
|
||||||
selectionOrder := make([]accountWithLoad, 0, len(available))
|
|
||||||
if requireCompact {
|
|
||||||
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
|
||||||
for _, item := range available {
|
|
||||||
if openAICompactSupportTier(item.account) == tier {
|
|
||||||
out = append(out, item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
selectionOrder = appendTier(selectionOrder, 2)
|
|
||||||
selectionOrder = appendTier(selectionOrder, 1)
|
|
||||||
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
|
||||||
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
|
||||||
selectionOrder = appendTier(selectionOrder, 0)
|
|
||||||
} else {
|
|
||||||
selectionOrder = append(selectionOrder, available...)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range selectionOrder {
|
|
||||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
|
||||||
if fresh == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
|
||||||
if fresh == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
|
||||||
if err == nil && result.Acquired {
|
|
||||||
if sessionHash != "" {
|
|
||||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
|
||||||
}
|
|
||||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1868,6 +1922,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
|||||||
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
|
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return fresh
|
return fresh
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1889,6 +1946,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
|||||||
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
|
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return latest
|
return latest
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1935,6 +1995,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) {
|
||||||
|
selection, err := s.newSelectionResult(ctx, account, true, release, nil)
|
||||||
|
if err != nil && release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
|
return selection, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
return s.cfg.Gateway.Scheduling
|
return s.cfg.Gateway.Scheduling
|
||||||
@ -1996,7 +2064,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
|
|||||||
|
|
||||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward forwards request to OpenAI API
|
// Forward forwards request to OpenAI API
|
||||||
@ -3278,9 +3346,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
|
|||||||
}
|
}
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||||
if s.rateLimitService != nil {
|
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
||||||
}
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
@ -3321,12 +3387,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
|||||||
}
|
}
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||||
if s.rateLimitService != nil {
|
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
|
||||||
// Passthrough mode preserves the raw upstream error response, but runtime
|
// 避免粘性路由继续复用刚被限流的账号。
|
||||||
// account state still needs to be updated so sticky routing can stop
|
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
// reusing a freshly rate-limited account.
|
|
||||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
||||||
}
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
Platform: account.Platform,
|
Platform: account.Platform,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
@ -4075,10 +4138,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle upstream error (mark account status)
|
// Handle upstream error (mark account status)
|
||||||
shouldDisable := false
|
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
if s.rateLimitService != nil {
|
|
||||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
||||||
}
|
|
||||||
kind := "http_error"
|
kind := "http_error"
|
||||||
if shouldDisable {
|
if shouldDisable {
|
||||||
kind = "failover"
|
kind = "failover"
|
||||||
@ -4210,12 +4270,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Track rate limits and decide whether to trigger secondary failover.
|
// Track rate limits and decide whether to trigger secondary failover.
|
||||||
shouldDisable := false
|
shouldDisable := s.handleOpenAIAccountUpstreamError(
|
||||||
if s.rateLimitService != nil {
|
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||||
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
)
|
||||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
kind := "http_error"
|
kind := "http_error"
|
||||||
if shouldDisable {
|
if shouldDisable {
|
||||||
kind = "failover"
|
kind = "failover"
|
||||||
|
|||||||
@ -80,6 +80,7 @@ type OpenAITokenProvider struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache OpenAITokenCache
|
tokenCache OpenAITokenCache
|
||||||
openAIOAuthService *OpenAIOAuthService
|
openAIOAuthService *OpenAIOAuthService
|
||||||
|
runtimeBlocker AccountRuntimeBlocker
|
||||||
metrics *openAITokenRuntimeMetricsStore
|
metrics *openAITokenRuntimeMetricsStore
|
||||||
refreshAPI *OAuthRefreshAPI
|
refreshAPI *OAuthRefreshAPI
|
||||||
executor OAuthRefreshExecutor
|
executor OAuthRefreshExecutor
|
||||||
@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
|||||||
p.refreshPolicy = policy
|
p.refreshPolicy = policy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||||
|
p.runtimeBlocker = blocker
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return OpenAITokenRuntimeMetrics{}
|
return OpenAITokenRuntimeMetrics{}
|
||||||
@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account
|
|||||||
if p == nil || p.accountRepo == nil || account == nil {
|
if p == nil || p.accountRepo == nil || account == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if p.runtimeBlocker != nil {
|
||||||
|
p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token")
|
||||||
|
}
|
||||||
bgCtx := context.Background()
|
bgCtx := context.Background()
|
||||||
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
|
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
|
||||||
slog.Warn("openai_token_provider.set_error_failed",
|
slog.Warn("openai_token_provider.set_error_failed",
|
||||||
|
|||||||
@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
|||||||
cache.getErr = errors.New("simulated cache miss")
|
cache.getErr = errors.New("simulated cache miss")
|
||||||
|
|
||||||
provider := NewOpenAITokenProvider(repo, cache, nil)
|
provider := NewOpenAITokenProvider(repo, cache, nil)
|
||||||
|
blocker := &runtimeBlockRecorder{}
|
||||||
|
provider.SetAccountRuntimeBlocker(blocker)
|
||||||
|
|
||||||
token, err := provider.GetAccessToken(context.Background(), account)
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
|||||||
|
|
||||||
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
|
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
|
||||||
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
|
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
|
||||||
|
require.Len(t, blocker.accounts, 1)
|
||||||
|
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||||
|
require.Equal(t, "missing_refresh_token", blocker.reasons[0])
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex
|
|||||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||||
|
|||||||
@ -28,10 +28,16 @@ type RateLimitService struct {
|
|||||||
openAI403CounterCache OpenAI403CounterCache
|
openAI403CounterCache OpenAI403CounterCache
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
tokenCacheInvalidator TokenCacheInvalidator
|
tokenCacheInvalidator TokenCacheInvalidator
|
||||||
|
runtimeBlocker AccountRuntimeBlocker
|
||||||
usageCacheMu sync.RWMutex
|
usageCacheMu sync.RWMutex
|
||||||
usageCache map[int64]*geminiUsageCacheEntry
|
usageCache map[int64]*geminiUsageCacheEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccountRuntimeBlocker interface {
|
||||||
|
BlockAccountScheduling(account *Account, until time.Time, reason string)
|
||||||
|
ClearAccountSchedulingBlock(accountID int64)
|
||||||
|
}
|
||||||
|
|
||||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||||
type SuccessfulTestRecoveryResult struct {
|
type SuccessfulTestRecoveryResult struct {
|
||||||
ClearedError bool
|
ClearedError bool
|
||||||
@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
|
|||||||
s.tokenCacheInvalidator = invalidator
|
s.tokenCacheInvalidator = invalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||||
|
s.runtimeBlocker = blocker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||||
|
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||||
|
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorPolicyResult 表示错误策略检查的结果
|
// ErrorPolicyResult 表示错误策略检查的结果
|
||||||
type ErrorPolicyResult int
|
type ErrorPolicyResult int
|
||||||
|
|
||||||
@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
cooldownMinutes = 10
|
cooldownMinutes = 10
|
||||||
}
|
}
|
||||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "oauth_401")
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
|
||||||
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
|||||||
|
|
||||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||||
|
s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error")
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account
|
|||||||
|
|
||||||
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
|
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
|
||||||
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
|
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp")
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||||
s.handleAuthError(ctx, account, msg)
|
s.handleAuthError(ctx, account, msg)
|
||||||
@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
|
|||||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||||
|
s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code")
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||||
return
|
return
|
||||||
@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
|
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
|
||||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||||
|
s.notifyAccountSchedulingBlocked(account, *resetAt, "429")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
|
|
||||||
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||||
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||||
|
s.notifyAccountSchedulingBlocked(account, result.resetAt, "429")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||||
resetTime := time.Unix(*resetAt, 0)
|
resetTime := time.Unix(*resetAt, 0)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
// 尝试解析 Gemini 格式(用于其他平台)
|
// 尝试解析 Gemini 格式(用于其他平台)
|
||||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||||
resetTime := time.Unix(*resetAt, 0)
|
resetTime := time.Unix(*resetAt, 0)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
|||||||
resetAt := time.Unix(ts, 0)
|
resetAt := time.Unix(ts, 0)
|
||||||
|
|
||||||
// 标记限流状态
|
// 标记限流状态
|
||||||
|
s.notifyAccountSchedulingBlocked(account, resetAt, "429")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun
|
|||||||
|
|
||||||
resetAt := time.Now().Add(cooldown)
|
resetAt := time.Now().Add(cooldown)
|
||||||
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
|
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
|
||||||
|
s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback")
|
||||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "529")
|
||||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||||
return
|
return
|
||||||
@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.ResetOpenAI403Counter(ctx, accountID)
|
s.ResetOpenAI403Counter(ctx, accountID)
|
||||||
|
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
|
|||||||
}
|
}
|
||||||
if result.ClearedError || result.ClearedRateLimit {
|
if result.ClearedError || result.ClearedRateLimit {
|
||||||
s.ResetOpenAI403Counter(ctx, accountID)
|
s.ResetOpenAI403Counter(ctx, accountID)
|
||||||
|
if result.ClearedError && !result.ClearedRateLimit {
|
||||||
|
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
|||||||
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
|
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
|
||||||
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
|
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
|
||||||
}
|
}
|
||||||
|
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
|||||||
reason = strings.TrimSpace(state.ErrorMessage)
|
reason = strings.TrimSpace(state.ErrorMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable")
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
|||||||
reason = state.ErrorMessage
|
reason = state.ErrorMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable")
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
|||||||
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
|
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
|
||||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||||
|
|
||||||
|
s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error")
|
||||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -6,16 +6,36 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type runtimeBlockRecorder struct {
|
||||||
|
accounts []*Account
|
||||||
|
until []time.Time
|
||||||
|
reasons []string
|
||||||
|
clearedIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||||
|
r.accounts = append(r.accounts, account)
|
||||||
|
r.until = append(r.until, until)
|
||||||
|
r.reasons = append(r.reasons, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) {
|
||||||
|
r.clearedIDs = append(r.clearedIDs, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
|
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
|
||||||
repo := &rateLimitAccountRepoStub{}
|
repo := &rateLimitAccountRepoStub{}
|
||||||
counter := &openAI403CounterCacheStub{counts: []int64{1}}
|
counter := &openAI403CounterCacheStub{counts: []int64{1}}
|
||||||
|
blocker := &runtimeBlockRecorder{}
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
service.SetOpenAI403CounterCache(counter)
|
service.SetOpenAI403CounterCache(counter)
|
||||||
|
service.SetAccountRuntimeBlocker(blocker)
|
||||||
account := &Account{
|
account := &Account{
|
||||||
ID: 301,
|
ID: 301,
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
|
|||||||
require.Equal(t, 1, repo.tempCalls)
|
require.Equal(t, 1, repo.tempCalls)
|
||||||
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
|
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
|
||||||
require.Contains(t, repo.lastTempReason, "(1/3)")
|
require.Contains(t, repo.lastTempReason, "(1/3)")
|
||||||
|
require.Len(t, blocker.accounts, 1)
|
||||||
|
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||||
|
require.Equal(t, "openai_403_temp", blocker.reasons[0])
|
||||||
|
require.True(t, blocker.until[0].After(time.Now()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
|
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
|
||||||
|
|||||||
@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
cache := &tempUnschedCacheRecorder{}
|
cache := &tempUnschedCacheRecorder{}
|
||||||
|
blocker := &runtimeBlockRecorder{}
|
||||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||||
|
svc.SetAccountRuntimeBlocker(blocker)
|
||||||
|
|
||||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
|||||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||||
|
require.Equal(t, []int64{42}, blocker.clearedIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||||
|
|||||||
@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) {
|
||||||
|
cache := &snapshotHydrationCache{
|
||||||
|
accounts: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil)
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
|
}
|
||||||
|
releaseCalls := 0
|
||||||
|
|
||||||
|
selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() {
|
||||||
|
releaseCalls++
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected hydration error")
|
||||||
|
}
|
||||||
|
if selection != nil {
|
||||||
|
t.Fatalf("expected nil selection on hydration error")
|
||||||
|
}
|
||||||
|
if releaseCalls != 1 {
|
||||||
|
t.Fatalf("expected release to be called once, got %d", releaseCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
|
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
|
||||||
cache := &snapshotHydrationCache{
|
cache := &snapshotHydrationCache{
|
||||||
snapshot: []*Account{
|
snapshot: []*Account{
|
||||||
|
|||||||
@ -27,6 +27,7 @@ type TokenRefreshService struct {
|
|||||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||||
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||||
|
runtimeBlocker AccountRuntimeBlocker
|
||||||
|
|
||||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||||
privacyClientFactory PrivacyClientFactory
|
privacyClientFactory PrivacyClientFactory
|
||||||
@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
|||||||
s.refreshPolicy = policy
|
s.refreshPolicy = policy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||||
|
s.runtimeBlocker = blocker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||||
|
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||||
|
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// Start 启动后台刷新服务
|
// Start 启动后台刷新服务
|
||||||
func (s *TokenRefreshService) Start() {
|
func (s *TokenRefreshService) Start() {
|
||||||
if !s.cfg.Enabled {
|
if !s.cfg.Enabled {
|
||||||
@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
|
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
|
||||||
if isNonRetryableRefreshError(err) {
|
if isNonRetryableRefreshError(err) {
|
||||||
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable")
|
||||||
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
||||||
slog.Error("token_refresh.set_error_status_failed",
|
slog.Error("token_refresh.set_error_status_failed",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
|||||||
// 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试)
|
// 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试)
|
||||||
until := time.Now().Add(tokenRefreshTempUnschedDuration)
|
until := time.Now().Add(tokenRefreshTempUnschedDuration)
|
||||||
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
|
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
|
||||||
|
s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted")
|
||||||
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
|
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
|
||||||
slog.Warn("token_refresh.set_temp_unschedulable_failed",
|
slog.Warn("token_refresh.set_temp_unschedulable_failed",
|
||||||
"account_id", account.ID,
|
"account_id", account.ID,
|
||||||
@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||||
|
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||||
@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||||
|
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||||
}
|
}
|
||||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||||
if s.tempUnschedCache != nil {
|
if s.tempUnschedCache != nil {
|
||||||
|
|||||||
@ -59,6 +59,7 @@ func ProvideTokenRefreshService(
|
|||||||
privacyClientFactory PrivacyClientFactory,
|
privacyClientFactory PrivacyClientFactory,
|
||||||
proxyRepo ProxyRepository,
|
proxyRepo ProxyRepository,
|
||||||
refreshAPI *OAuthRefreshAPI,
|
refreshAPI *OAuthRefreshAPI,
|
||||||
|
runtimeBlocker AccountRuntimeBlocker,
|
||||||
) *TokenRefreshService {
|
) *TokenRefreshService {
|
||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||||
// 注入 OpenAI privacy opt-out 依赖
|
// 注入 OpenAI privacy opt-out 依赖
|
||||||
@ -67,6 +68,7 @@ func ProvideTokenRefreshService(
|
|||||||
svc.SetRefreshAPI(refreshAPI)
|
svc.SetRefreshAPI(refreshAPI)
|
||||||
// 调用侧显式注入后台刷新策略,避免策略漂移
|
// 调用侧显式注入后台刷新策略,避免策略漂移
|
||||||
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
||||||
|
svc.SetAccountRuntimeBlocker(runtimeBlocker)
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@ -183,6 +185,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
|||||||
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
||||||
}
|
}
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
|
svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond)
|
||||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||||
}
|
}
|
||||||
return svc
|
return svc
|
||||||
@ -455,6 +458,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAdminService,
|
NewAdminService,
|
||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
NewOpenAIGatewayService,
|
NewOpenAIGatewayService,
|
||||||
|
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
|
||||||
NewOAuthService,
|
NewOAuthService,
|
||||||
NewOpenAIOAuthService,
|
NewOpenAIOAuthService,
|
||||||
NewGeminiOAuthService,
|
NewGeminiOAuthService,
|
||||||
|
|||||||
@ -405,6 +405,9 @@ gateway:
|
|||||||
# Enable batch load calculation for scheduling
|
# Enable batch load calculation for scheduling
|
||||||
# 启用调度批量负载计算
|
# 启用调度批量负载计算
|
||||||
load_batch_enabled: true
|
load_batch_enabled: true
|
||||||
|
# Tiny in-process TTL for batch load reads in milliseconds (0 disables)
|
||||||
|
# 调度批量负载读取的进程内短缓存 TTL(毫秒,0 表示禁用)
|
||||||
|
load_batch_cache_ttl_ms: 200
|
||||||
# Slot cleanup interval (duration)
|
# Slot cleanup interval (duration)
|
||||||
# 并发槽位清理周期(时间段)
|
# 并发槽位清理周期(时间段)
|
||||||
slot_cleanup_interval: 30s
|
slot_cleanup_interval: 30s
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user