fix: optimize OpenAI account cooldown scheduling

This commit is contained in:
shaw 2026-05-23 10:18:43 +08:00
parent f59d9a5f8e
commit 1e406fed52
29 changed files with 1169 additions and 370 deletions

View File

@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
@ -173,24 +188,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@ -261,7 +261,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)

View File

@ -1009,7 +1009,8 @@ type GatewaySchedulingConfig struct {
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"`
// 快照桶读取时的 MGET 分块大小
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
// 快照重建时的缓存写入分块大小
@ -1828,6 +1829,7 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200)
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
@ -2634,6 +2636,9 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 {
return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative")
}
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
}

View File

@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 {
t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS)
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling load batch cache ttl",
mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 },
wantErr: "gateway.scheduling.load_batch_cache_ttl_ms",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },

View File

@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if channelMappingMsg.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -258,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -1258,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)

View File

@ -531,6 +531,7 @@ type adminServiceImpl struct {
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
privacyClientFactory PrivacyClientFactory
runtimeBlocker AccountRuntimeBlocker
}
type userGroupRateBatchReader interface {
@ -556,6 +557,7 @@ func NewAdminService(
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
privacyClientFactory PrivacyClientFactory,
runtimeBlocker AccountRuntimeBlocker,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@ -575,6 +577,7 @@ func NewAdminService(
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
privacyClientFactory: privacyClientFactory,
runtimeBlocker: runtimeBlocker,
}
}
@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
return nil, err
}
if s.runtimeBlocker != nil {
s.runtimeBlocker.ClearAccountSchedulingBlock(id)
}
return s.accountRepo.GetByID(ctx, id)
}

View File

@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
TempUnschedulableReason: "missing refresh token",
},
}
svc := &adminServiceImpl{accountRepo: repo}
blocker := &runtimeBlockRecorder{}
svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker}
updated, err := svc.ClearAccountError(context.Background(), 31)
require.NoError(t, err)
@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
require.Nil(t, updated.RateLimitResetAt)
require.Nil(t, updated.TempUnschedulableUntil)
require.Empty(t, updated.TempUnschedulableReason)
require.Equal(t, []int64{31}, blocker.clearedIDs)
}

View File

@ -3,13 +3,17 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// ConcurrencyCache 定义并发控制的缓存接口
@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error
}
const (
// Default extra wait slots beyond concurrency limit
// 默认等待队列额外槽位
defaultExtraWaitSlots = 20
defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond
accountLoadBatchFetchTimeout = 3 * time.Second
maxAccountLoadBatchCacheEntries = 256
)
// ConcurrencyService manages concurrent request limiting for accounts and users
// ConcurrencyService 管理账号和用户的并发限制。
type ConcurrencyService struct {
cache ConcurrencyCache
accountLoadCacheTTL atomic.Int64
accountLoadCacheMu sync.RWMutex
accountLoadCache map[string]cachedAccountLoadBatch
accountLoadGroup singleflight.Group
}
// NewConcurrencyService creates a new ConcurrencyService
type cachedAccountLoadBatch struct {
loadMap map[int64]*AccountLoadInfo
expiresAt time.Time
}
// NewConcurrencyService 创建并发控制服务。
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache}
svc := &ConcurrencyService{
cache: cache,
accountLoadCache: make(map[string]cachedAccountLoadBatch),
}
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
return svc
}
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
if s == nil {
return
}
s.accountLoadCacheTTL.Store(int64(ttl))
if ttl <= 0 {
s.accountLoadCacheMu.Lock()
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
s.accountLoadCacheMu.Unlock()
}
}
// AcquireResult represents the result of acquiring a concurrency slot
@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
// GetAccountsLoadBatch 批量获取账号负载信息。
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
return s.getAccountsLoadBatch(ctx, accounts, true)
}
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
return s.getAccountsLoadBatch(ctx, accounts, false)
}
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*AccountLoadInfo{}, nil
}
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
ttl := time.Duration(s.accountLoadCacheTTL.Load())
if !allowCache || ttl <= 0 {
return s.fetchAccountsLoadBatch(ctx, accounts)
}
key := accountLoadBatchCacheKey(accounts)
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
return cached, nil
}
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
now := time.Now()
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
return cached, nil
}
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
if fetchErr != nil {
return nil, fetchErr
}
cached := cloneAccountLoadMap(loadMap)
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
return cached, nil
})
if err != nil {
return nil, err
}
loadMap, _ := value.(map[int64]*AccountLoadInfo)
if loadMap == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return loadMap, nil
}
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
baseCtx := context.Background()
if ctx != nil {
baseCtx = context.WithoutCancel(ctx)
}
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
defer cancel()
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
}
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
s.accountLoadCacheMu.RLock()
cached, ok := s.accountLoadCache[key]
s.accountLoadCacheMu.RUnlock()
if !ok {
return nil, false
}
if !now.Before(cached.expiresAt) {
s.accountLoadCacheMu.Lock()
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
delete(s.accountLoadCache, key)
}
s.accountLoadCacheMu.Unlock()
return nil, false
}
return cached.loadMap, true
}
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
s.accountLoadCacheMu.Lock()
if s.accountLoadCache == nil {
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
}
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
now := time.Now()
for cacheKey, cached := range s.accountLoadCache {
if !now.Before(cached.expiresAt) {
delete(s.accountLoadCache, cacheKey)
}
}
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
for cacheKey := range s.accountLoadCache {
delete(s.accountLoadCache, cacheKey)
break
}
}
}
s.accountLoadCache[key] = cachedAccountLoadBatch{
loadMap: loadMap,
expiresAt: expiresAt,
}
s.accountLoadCacheMu.Unlock()
}
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
hash := sha256.New()
var buf [16]byte
for _, account := range accounts {
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
_, _ = hash.Write(buf[:])
}
sum := hash.Sum(nil)
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
}
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
if len(loadMap) == 0 {
return map[int64]*AccountLoadInfo{}
}
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
for accountID, loadInfo := range loadMap {
if loadInfo == nil {
clone[accountID] = nil
continue
}
copied := *loadInfo
clone[accountID] = &copied
}
return clone
}
// GetUsersLoadBatch returns load info for multiple users.

View File

@ -7,7 +7,9 @@ import (
"errors"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct {
// 记录调用
releasedAccountIDs []int64
releasedRequestIDs []string
loadBatchCalls atomic.Int64
}
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in
return nil
}
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
c.loadBatchCalls.Add(1)
return c.loadBatch, c.loadBatchErr
}
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
require.Empty(t, result)
}
func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) {
cache := &stubConcurrencyCacheForTest{
loadBatch: map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
},
}
svc := NewConcurrencyService(cache)
svc.SetAccountLoadBatchCacheTTL(time.Second)
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
first, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 1, first[int64(1)].CurrentConcurrency)
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
second, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 1, second[int64(1)].CurrentConcurrency)
require.Equal(t, int64(1), cache.loadBatchCalls.Load())
}
func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) {
cache := &stubConcurrencyCacheForTest{
loadBatch: map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
},
}
svc := NewConcurrencyService(cache)
svc.SetAccountLoadBatchCacheTTL(time.Second)
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
_, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency)
require.Equal(t, int64(2), cache.loadBatchCalls.Load())
}
func TestIncrementWaitCount_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
svc := NewConcurrencyService(cache)

View File

@ -0,0 +1,169 @@
package service
import (
"context"
"net/http"
"time"
)
const (
openAIAccountStateUpdateTimeout = 5 * time.Second
openAIOAuth429FallbackCooldown = 5 * time.Second
openAIStopSchedulingBridgeCooldown = 2 * time.Minute
openAIOAuth429StormWindow = 10 * time.Second
openAIOAuth429StormThreshold = 20
openAIOAuth429StormMaxAccountSwitches = 1
)
func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) {
base := context.Background()
if ctx != nil {
base = context.WithoutCancel(ctx)
}
return context.WithTimeout(base, openAIAccountStateUpdateTimeout)
}
func isOpenAIOAuthAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
func isOpenAIAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI
}
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
stateCtx, cancel := openAIAccountStateContext(ctx)
defer cancel()
if statusCode == http.StatusTooManyRequests {
s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody)
}
if s == nil || account == nil || s.rateLimitService == nil {
return false
}
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
if shouldDisable {
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
}
return shouldDisable
}
func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
if s == nil || !isOpenAIOAuthAccount(account) {
return
}
s.recordOpenAIOAuth429()
cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown)
if s.rateLimitService != nil {
if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) {
cooldownUntil = *resetAt
} else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil {
if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) {
cooldownUntil = resetAt
}
} else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 {
cooldownUntil = time.Now().Add(cooldown)
}
}
s.BlockAccountScheduling(account, cooldownUntil, "429")
}
func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) {
if s == nil || !isOpenAIAccount(account) {
return
}
now := time.Now()
blockUntil := until
if blockUntil.IsZero() || !blockUntil.After(now) {
blockUntil = now.Add(openAIStopSchedulingBridgeCooldown)
}
for {
current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
if !loaded {
actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil)
if !stored {
return
}
current = actual
}
currentUntil, ok := current.(time.Time)
if !ok || currentUntil.IsZero() {
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
return
}
continue
}
if currentUntil.After(blockUntil) {
return
}
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
return
}
}
}
func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) {
if s == nil || accountID <= 0 {
return
}
s.openaiAccountRuntimeBlockUntil.Delete(accountID)
}
func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool {
if s == nil || !isOpenAIAccount(account) {
return false
}
value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
if !ok {
return false
}
cooldownUntil, ok := value.(time.Time)
if !ok || cooldownUntil.IsZero() {
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
return false
}
if time.Now().Before(cooldownUntil) {
return true
}
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
return false
}
func (s *OpenAIGatewayService) recordOpenAIOAuth429() {
if s == nil {
return
}
now := time.Now()
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) {
s.openaiOAuth429WindowCount.Store(1)
return
}
}
s.openaiOAuth429WindowCount.Add(1)
}
func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool {
if s == nil {
return false
}
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
return false
}
return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold
}
func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool {
if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches {
return false
}
if !isOpenAIOAuthAccount(account) {
return false
}
return s.isOpenAIOAuth429Storm()
}

View File

@ -0,0 +1,101 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil)
apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil)
require.False(t, shouldDisable)
require.False(t, apiKeyShouldDisable)
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount))
}
func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) {
gateway := &OpenAIGatewayService{}
repo := &rateLimitAccountRepoStub{}
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
rateLimitService.SetAccountRuntimeBlocker(gateway)
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden"))
require.True(t, shouldDisable)
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
longUntil := time.Now().Add(10 * time.Minute)
svc.BlockAccountScheduling(account, longUntil, "oauth_401")
svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID)
require.True(t, ok)
actualUntil, ok := value.(time.Time)
require.True(t, ok)
require.WithinDuration(t, longUntil, actualUntil, time.Second)
}
func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429")
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
svc.ClearAccountSchedulingBlock(account.ID)
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
for i := 0; i < openAIOAuth429StormThreshold; i++ {
svc.recordOpenAIOAuth429()
}
require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0))
}

View File

@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct {
loadSkewMilliTotal atomic.Int64
}
type openAIAccountLoadPlan struct {
allCandidates []openAIAccountCandidateScore
candidates []openAIAccountCandidateScore
staleSnapshotCompactRetry []openAIAccountCandidateScore
selectionOrder []openAIAccountCandidateScore
candidateCount int
topK int
loadSkew float64
}
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
if m == nil {
return
@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
if acquireErr == nil && result != nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
return &AccountSelectionResult{
Account: account,
@ -586,6 +596,231 @@ func buildOpenAIWeightedSelectionOrder(
return order
}
func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan(
req OpenAIAccountScheduleRequest,
filtered []*Account,
loadMap map[int64]*AccountLoadInfo,
) openAIAccountLoadPlan {
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
errorRate, ttft, hasTTFT := 0.0, 0.0, false
if s.stats != nil {
errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID)
}
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
}
plan := openAIAccountLoadPlan{
allCandidates: allCandidates,
candidates: candidates,
staleSnapshotCompactRetry: staleSnapshotCompactRetry,
candidateCount: len(candidates),
}
if len(candidates) == 0 {
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
return plan
}
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
plan.candidates = candidates
plan.topK = s.service.openAIWSLBTopK()
if plan.topK > len(candidates) {
plan.topK = len(candidates)
}
if plan.topK <= 0 {
plan.topK = 1
}
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
return plan
}
func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder(
req OpenAIAccountScheduleRequest,
plan openAIAccountLoadPlan,
) []openAIAccountCandidateScore {
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || plan.topK <= 0 {
return nil
}
groupTopK := plan.topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
for _, candidate := range plan.candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates))
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...)
}
return selectionOrder
}
return buildSelectionOrder(plan.candidates)
}
func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
ctx context.Context,
req OpenAIAccountScheduleRequest,
selectionOrder []openAIAccountCandidateScore,
) (*AccountSelectionResult, bool, error) {
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, compactBlocked, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, compactBlocked, nil
}
}
return nil, compactBlocked, nil
}
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
@ -616,8 +851,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
if s.service.isOpenAIAccountRuntimeBlocked(account) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set")
_ = s.service.accountRepo.SetError(ctx, account.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
@ -645,208 +884,46 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap)
candidateCount := plan.candidateCount
topK := plan.topK
loadSkew := plan.loadSkew
selectionOrder := plan.selectionOrder
if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
// 时作为最后兜底snapshot 可能已陈旧)。
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
}
candidateCount := len(candidates)
loadSkew := 0.0
if len(candidates) > 0 {
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor
}
}
topK := 0
if len(candidates) > 0 {
topK = s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
}
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || topK <= 0 {
return nil
}
groupTopK := topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
for _, candidate := range candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
}
} else {
selectionOrder = buildSelectionOrder(candidates)
if req.RequireCompact && len(selectionOrder) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
if len(selectionOrder) == 0 {
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0)
}
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder)
if acquireErr != nil {
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil {
return result, candidateCount, topK, loadSkew, nil
}
if s.service.concurrencyService != nil {
if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil {
freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap)
if len(freshPlan.selectionOrder) > 0 {
freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder)
if freshAcquireErr != nil {
return nil, candidateCount, topK, loadSkew, freshAcquireErr
}
if freshResult != nil {
return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil
}
compactBlocked = compactBlocked || freshCompactBlocked
selectionOrder = freshPlan.selectionOrder
candidateCount = freshPlan.candidateCount
topK = freshPlan.topK
loadSkew = freshPlan.loadSkew
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, candidateCount, topK, loadSkew, nil
}
}
@ -893,6 +970,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
if account == nil {
return false
}
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
return false
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false
}

View File

@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -354,6 +354,9 @@ type OpenAIGatewayService struct {
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time
openaiOAuth429WindowStartUnixNano atomic.Int64
openaiOAuth429WindowCount atomic.Int64
openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle
@ -417,6 +420,12 @@ func NewOpenAIGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
if rateLimitService != nil {
rateLimitService.SetAccountRuntimeBlocker(svc)
}
if openAITokenProvider != nil {
openAITokenProvider.SetAccountRuntimeBlocker(svc)
}
svc.logOpenAIWSModeBootstrap()
return svc
}
@ -1381,13 +1390,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
}
hydrated, err := s.hydrateSelectedAccount(ctx, selected)
if err != nil {
return nil, err
}
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
}
return s.hydrateSelectedAccount(ctx, selected)
return hydrated, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
@ -1430,6 +1444,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(account) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
@ -1575,8 +1593,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
if err == nil && result != nil && result.Acquired {
return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
@ -1627,13 +1645,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if s.isOpenAIAccountRuntimeBlocked(account) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
if selectErr != nil {
return nil, selectErr
}
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
return selection, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@ -1665,6 +1689,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if !acc.IsSchedulable() {
continue
}
if s.isOpenAIAccountRuntimeBlocked(acc) {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
@ -1687,6 +1714,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
})
}
tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) {
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) == 0 {
return nil, false, nil
}
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
selectionOrder := make([]accountWithLoad, 0, len(available))
if requireCompact {
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
for _, item := range available {
if openAICompactSupportTier(item.account) == tier {
out = append(out, item)
}
}
return out
}
selectionOrder = appendTier(selectionOrder, 2)
selectionOrder = appendTier(selectionOrder, 1)
// tier 0 候选作为兜底追加DB recheck 时若发现 cache tier 0 实际
// 已升级为 1/2探测刚跑完cache 尚未刷新),仍可正常命中。
selectionOrder = appendTier(selectionOrder, 0)
} else {
selectionOrder = append(selectionOrder, available...)
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
if selectErr != nil {
return nil, true, selectErr
}
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return selection, true, nil
}
}
return nil, true, nil
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
ordered := append([]*Account(nil), candidates...)
@ -1707,87 +1820,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
if selectErr != nil {
return nil, selectErr
}
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
return selection, nil
}
}
} else {
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
selectionOrder := make([]accountWithLoad, 0, len(available))
if requireCompact {
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
for _, item := range available {
if openAICompactSupportTier(item.account) == tier {
out = append(out, item)
}
}
return out
}
selectionOrder = appendTier(selectionOrder, 2)
selectionOrder = appendTier(selectionOrder, 1)
// tier 0 候选作为兜底追加DB recheck 时若发现 cache tier 0 实际
// 已升级为 1/2探测刚跑完cache 尚未刷新),仍可正常命中。
selectionOrder = appendTier(selectionOrder, 0)
} else {
selectionOrder = append(selectionOrder, available...)
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil {
return nil, selectErr
} else if selection != nil {
return selection, nil
} else if attempted {
if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil {
if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil {
return nil, selectErr
} else if selection != nil {
return selection, nil
}
}
}
@ -1868,6 +1922,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(fresh) {
return nil
}
return fresh
}
@ -1889,6 +1946,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(latest) {
return nil
}
return latest
}
@ -1935,6 +1995,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
}, nil
}
func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) {
selection, err := s.newSelectionResult(ctx, account, true, release, nil)
if err != nil && release != nil {
release()
}
return selection, err
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
@ -1996,7 +2064,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// Forward forwards request to OpenAI API
@ -3278,9 +3346,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -3321,12 +3387,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
// 避免粘性路由继续复用刚被限流的账号。
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -4075,10 +4138,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
}
// Handle upstream error (mark account status)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
kind := "http_error"
if shouldDisable {
kind = "failover"
@ -4210,12 +4270,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
kind := "http_error"
if shouldDisable {
kind = "failover"

View File

@ -80,6 +80,7 @@ type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
runtimeBlocker AccountRuntimeBlocker
metrics *openAITokenRuntimeMetricsStore
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
p.runtimeBlocker = blocker
}
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
if p == nil {
return OpenAITokenRuntimeMetrics{}
@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account
if p == nil || p.accountRepo == nil || account == nil {
return
}
if p.runtimeBlocker != nil {
p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token")
}
bgCtx := context.Background()
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
slog.Warn("openai_token_provider.set_error_failed",

View File

@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
cache.getErr = errors.New("simulated cache miss")
provider := NewOpenAITokenProvider(repo, cache, nil)
blocker := &runtimeBlockRecorder{}
provider.SetAccountRuntimeBlocker(blocker)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
require.Len(t, blocker.accounts, 1)
require.Equal(t, account.ID, blocker.accounts[0].ID)
require.Equal(t, "missing_refresh_token", blocker.reasons[0])
}

View File

@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
}
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {

View File

@ -28,10 +28,16 @@ type RateLimitService struct {
openAI403CounterCache OpenAI403CounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
runtimeBlocker AccountRuntimeBlocker
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type AccountRuntimeBlocker interface {
BlockAccountScheduling(account *Account, until time.Time, reason string)
ClearAccountSchedulingBlock(accountID int64)
}
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
type SuccessfulTestRecoveryResult struct {
ClearedError bool
@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
s.tokenCacheInvalidator = invalidator
}
func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
s.runtimeBlocker = blocker
}
func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
if s == nil || s.runtimeBlocker == nil || account == nil {
return
}
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
}
func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) {
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
return
}
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
}
// ErrorPolicyResult 表示错误策略检查的结果
type ErrorPolicyResult int
@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
s.notifyAccountSchedulingBlocked(account, until, "oauth_401")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
}
@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error")
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
s.handleAuthError(ctx, account, msg)
@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code")
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
s.persistOpenAICodexSnapshot(ctx, account, headers)
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
s.notifyAccountSchedulingBlocked(account, *resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 2. Anthropic 平台:尝试解析 per-window 头5h / 7d选择实际触发的窗口
if result := calculateAnthropic429ResetTime(headers); result != nil {
s.notifyAccountSchedulingBlocked(account, result.resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
resetAt := time.Unix(ts, 0)
// 标记限流状态
s.notifyAccountSchedulingBlocked(account, resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun
resetAt := time.Now().Add(cooldown)
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
s.notifyAccountSchedulingBlocked(account, until, "529")
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
}
}
s.ResetOpenAI403Counter(ctx, accountID)
s.notifyAccountSchedulingBlockCleared(accountID)
return nil
}
@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
}
if result.ClearedError || result.ClearedRateLimit {
s.ResetOpenAI403Counter(ctx, accountID)
if result.ClearedError && !result.ClearedRateLimit {
s.notifyAccountSchedulingBlockCleared(accountID)
}
}
return result, nil
@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
}
s.notifyAccountSchedulingBlockCleared(accountID)
return nil
}
@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
reason = strings.TrimSpace(state.ErrorMessage)
}
s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
reason = state.ErrorMessage
}
s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false
@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error")
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false

View File

@ -6,16 +6,36 @@ import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type runtimeBlockRecorder struct {
accounts []*Account
until []time.Time
reasons []string
clearedIDs []int64
}
func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) {
r.accounts = append(r.accounts, account)
r.until = append(r.until, until)
r.reasons = append(r.reasons, reason)
}
func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) {
r.clearedIDs = append(r.clearedIDs, accountID)
}
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
counter := &openAI403CounterCacheStub{counts: []int64{1}}
blocker := &runtimeBlockRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetOpenAI403CounterCache(counter)
service.SetAccountRuntimeBlocker(blocker)
account := &Account{
ID: 301,
Platform: PlatformOpenAI,
@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
require.Equal(t, 1, repo.tempCalls)
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
require.Contains(t, repo.lastTempReason, "(1/3)")
require.Len(t, blocker.accounts, 1)
require.Equal(t, account.ID, blocker.accounts[0].ID)
require.Equal(t, "openai_403_temp", blocker.reasons[0])
require.True(t, blocker.until[0].After(time.Now()))
}
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {

View File

@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
},
}
cache := &tempUnschedCacheRecorder{}
blocker := &runtimeBlockRecorder{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
svc.SetAccountRuntimeBlocker(blocker)
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
require.NoError(t, err)
@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
require.Equal(t, 1, repo.clearModelRateLimitCalls)
require.Equal(t, 1, repo.clearTempUnschedCalls)
require.Equal(t, []int64{42}, cache.deletedIDs)
require.Equal(t, []int64{42}, blocker.clearedIDs)
}
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {

View File

@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul
}
}
func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) {
cache := &snapshotHydrationCache{
accounts: map[int64]*Account{},
}
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil)
svc := &OpenAIGatewayService{
schedulerSnapshot: schedulerSnapshot,
}
releaseCalls := 0
selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() {
releaseCalls++
})
if err == nil {
t.Fatalf("expected hydration error")
}
if selection != nil {
t.Fatalf("expected nil selection on hydration error")
}
if releaseCalls != 1 {
t.Fatalf("expected release to be called once, got %d", releaseCalls)
}
}
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
cache := &snapshotHydrationCache{
snapshot: []*Account{

View File

@ -27,6 +27,7 @@ type TokenRefreshService struct {
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
refreshAPI *OAuthRefreshAPI // 统一刷新 API
runtimeBlocker AccountRuntimeBlocker
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
privacyClientFactory PrivacyClientFactory
@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
s.refreshPolicy = policy
}
func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
s.runtimeBlocker = blocker
}
func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
if s == nil || s.runtimeBlocker == nil || account == nil {
return
}
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
}
func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) {
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
return
}
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {
@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 不可重试错误invalid_grant/invalid_client 等)直接标记 error 状态并返回
if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable")
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 设置临时不可调度 10 分钟(不标记 error保持 status=active 让下个刷新周期能继续尝试)
until := time.Now().Add(tokenRefreshTempUnschedDuration)
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted")
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
slog.Warn("token_refresh.set_temp_unschedulable_failed",
"account_id", account.ID,
@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
)
} else {
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
s.notifyAccountSchedulingBlockCleared(account.ID)
}
}
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
s.notifyAccountSchedulingBlockCleared(account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {

View File

@ -59,6 +59,7 @@ func ProvideTokenRefreshService(
privacyClientFactory PrivacyClientFactory,
proxyRepo ProxyRepository,
refreshAPI *OAuthRefreshAPI,
runtimeBlocker AccountRuntimeBlocker,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 OpenAI privacy opt-out 依赖
@ -67,6 +68,7 @@ func ProvideTokenRefreshService(
svc.SetRefreshAPI(refreshAPI)
// 调用侧显式注入后台刷新策略,避免策略漂移
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
svc.SetAccountRuntimeBlocker(runtimeBlocker)
svc.Start()
return svc
}
@ -183,6 +185,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil {
svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond)
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}
return svc
@ -455,6 +458,7 @@ var ProviderSet = wire.NewSet(
NewAdminService,
NewGatewayService,
NewOpenAIGatewayService,
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,

View File

@ -405,6 +405,9 @@ gateway:
# Enable batch load calculation for scheduling
# 启用调度批量负载计算
load_batch_enabled: true
# Tiny in-process TTL for batch load reads in milliseconds (0 disables)
# 调度批量负载读取的进程内短缓存 TTL毫秒0 表示禁用)
load_batch_cache_ttl_ms: 200
# Slot cleanup interval (duration)
# 并发槽位清理周期(时间段)
slot_cleanup_interval: 30s