feat: merge feat/omniroute-ideas — P2C scheduler, quota scoring, tier fallback

This commit is contained in:
win 2026-04-29 15:42:37 +08:00
commit fdd2d08a4d
37 changed files with 2360 additions and 70 deletions

View File

@ -188,7 +188,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@ -203,7 +204,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService) requestEventBus := service.NewRequestEventBus()
opsHandler := admin.NewOpsHandler(opsService, requestEventBus)
updateCache := repository.NewUpdateCache(redisClient) updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo) serviceBuildInfo := provideServiceBuildInfo(buildInfo)
@ -244,7 +246,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService) totpHandler := handler.NewTotpHandler(totpService)
@ -257,7 +259,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient) healthService := service.NewHealthService(db, redisClient)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)

View File

@ -691,6 +691,14 @@ type GatewayConfig struct {
// UserMessageQueue: 用户消息串行队列配置 // UserMessageQueue: 用户消息串行队列配置
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
// RPMSmoothing: RPM 令牌桶平滑配置
// 启用后RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
// ContextCompression: 主动上下文压缩配置
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
} }
type GatewayAntigravityLSWorkerConfig struct { type GatewayAntigravityLSWorkerConfig struct {
@ -745,6 +753,37 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
return "" return ""
} }
// RPMSmoothingConfig RPM 令牌桶平滑配置
type RPMSmoothingConfig struct {
// Enabled: 是否启用 RPM 令牌桶平滑(默认 false
// 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。
Enabled bool `mapstructure:"enabled"`
// MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429默认 5000
MaxWaitMS int `mapstructure:"max_wait_ms"`
}
// MaxWait returns the configured wait duration, defaulting to 5s.
func (c *RPMSmoothingConfig) MaxWait() time.Duration {
if c.MaxWaitMS <= 0 {
return 5 * time.Second
}
return time.Duration(c.MaxWaitMS) * time.Millisecond
}
// ContextCompressionConfig 主动上下文压缩配置
type ContextCompressionConfig struct {
// MaxTokens: 压缩目标 token 数chars/4 近似),超出时从最旧消息开始裁剪(默认 190000
MaxTokens int `mapstructure:"max_tokens"`
}
// GetMaxTokens returns the configured token budget, defaulting to 190 000.
func (c *ContextCompressionConfig) GetMaxTokens() int {
if c.MaxTokens <= 0 {
return 190_000
}
return c.MaxTokens
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct { type GatewayOpenAIWSConfig struct {
@ -828,6 +867,8 @@ type GatewayOpenAIWSConfig struct {
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
// EnableP2CScheduling: 启用 Power-of-Two-Choices 调度(默认 false使用 top-K 加权随机)
EnableP2CScheduling bool `mapstructure:"enable_p2c_scheduling"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
} }
@ -838,6 +879,8 @@ type GatewayOpenAIWSSchedulerScoreWeights struct {
Queue float64 `mapstructure:"queue"` Queue float64 `mapstructure:"queue"`
ErrorRate float64 `mapstructure:"error_rate"` ErrorRate float64 `mapstructure:"error_rate"`
TTFT float64 `mapstructure:"ttft"` TTFT float64 `mapstructure:"ttft"`
// Quota: 剩余配额比例权重0 表示不参与打分)
Quota float64 `mapstructure:"quota"`
} }
// GatewayUsageRecordConfig 使用量记录异步队列配置 // GatewayUsageRecordConfig 使用量记录异步队列配置
@ -992,6 +1035,10 @@ type GatewaySchedulingConfig struct {
// 全量重建周期配置 // 全量重建周期配置
// 全量重建周期0 表示禁用 // 全量重建周期0 表示禁用
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"` FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
// EnableTierFallbackChain: 启用跨档降级链(订阅 → API Key → Bedrock默认 false
// 仅对 Anthropic 平台生效;启用后账号按类型分层,优先使用订阅账号,依次降级。
EnableTierFallbackChain bool `mapstructure:"enable_tier_fallback_chain"`
} }
func (s *ServerConfig) Address() string { func (s *ServerConfig) Address() string {
@ -2504,7 +2551,8 @@ func (c *Config) Validate() error {
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Quota < 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
} }
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +

View File

@ -16,7 +16,8 @@ import (
) )
type OpsHandler struct { type OpsHandler struct {
opsService *service.OpsService opsService *service.OpsService
requestEventBus *service.RequestEventBus
} }
// GetErrorLogByID returns ops error log detail. // GetErrorLogByID returns ops error log detail.
@ -70,8 +71,8 @@ func parseOpsViewParam(c *gin.Context) string {
} }
} }
func NewOpsHandler(opsService *service.OpsService) *OpsHandler { func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler {
return &OpsHandler{opsService: opsService} return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus}
} }
// GetErrorLogs lists ops error logs. // GetErrorLogs lists ops error logs.

View File

@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService {
} }
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t)) h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, false) r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
} }
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t)) h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, false) r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
} }
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t)) h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, true) r := newOpsRuntimeRouter(h, true)
payload := map[string]any{ payload := map[string]any{

View File

@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
} }
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil) h := NewOpsHandler(nil, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) { func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{ svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false}, Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil) }, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) { func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true) r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true) r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true) r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) { func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true) r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{ svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false}, Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil) }, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true) r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
func TestOpsSystemLogHandler_Health(t *testing.T) { func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil) sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc) h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) {
} }
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) { func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil) h := NewOpsHandler(nil, nil)
r := newOpsSystemLogTestRouter(h, false) r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T
svc := service.NewOpsService(nil, nil, &config.Config{ svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false}, Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil) }, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc) h = NewOpsHandler(svc, nil)
r = newOpsSystemLogTestRouter(h, false) r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil) req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)

View File

@ -0,0 +1,198 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type requestStreamWSMessage struct {
Type string `json:"type"`
Data service.RequestEvent `json:"data"`
}
// RequestStreamWSHandler streams real-time request events to WebSocket clients.
// GET /api/v1/admin/ops/ws/requests
//
// Each connected client receives a JSON message per gateway dispatch:
//
// {"type":"request_event","data":{"timestamp":...,"method":"POST","path":"/v1/messages",
// "model":"claude-3-5-sonnet-20241022","account_id":42,"status":"success","latency_ms":1230}}
func (h *OpsHandler) RequestStreamWSHandler(c *gin.Context) {
clientIP := requestClientIP(c.Request)
if h == nil || h.opsService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
return
}
if h.requestEventBus == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "request event bus not initialized"})
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
return
}
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
return
}
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer func() {
if wsConnCount.Add(-1) == 0 {
scheduleQPSWSIdleStop()
}
}()
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] per-ip limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer releaseOpsWSIPSlot(clientIP)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] upgrade failed: %v", err)
return
}
defer func() { _ = conn.Close() }()
handleRequestStreamWebSocket(c.Request.Context(), conn, h.requestEventBus)
}
func handleRequestStreamWebSocket(parentCtx context.Context, conn *websocket.Conn, bus *service.RequestEventBus) {
if conn == nil || bus == nil {
return
}
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()
subID, eventCh := bus.Subscribe()
defer bus.Unsubscribe(subID)
var closeOnce sync.Once
closeConn := func() {
closeOnce.Do(func() { _ = conn.Close() })
}
closeFrameCh := make(chan []byte, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
})
conn.SetCloseHandler(func(code int, text string) error {
select {
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
default:
}
cancel()
return nil
})
for {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] read failed: %v", err)
}
return
}
}
}()
pingTicker := time.NewTicker(qpsWSPingInterval)
defer pingTicker.Stop()
writeWithTimeout := func(messageType int, data []byte) error {
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
return err
}
return conn.WriteMessage(messageType, data)
}
sendClose := func(closeFrame []byte) {
if closeFrame == nil {
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
}
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
}
for {
select {
case evt, ok := <-eventCh:
if !ok {
// channel closed by Unsubscribe
sendClose(nil)
closeConn()
wg.Wait()
return
}
msg, err := json.Marshal(requestStreamWSMessage{Type: "request_event", Data: evt})
if err != nil {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case closeFrame := <-closeFrameCh:
sendClose(closeFrame)
closeConn()
wg.Wait()
return
case <-ctx.Done():
var closeFrame []byte
select {
case closeFrame = <-closeFrameCh:
default:
}
sendClose(closeFrame)
closeConn()
wg.Wait()
return
}
}
}

View File

@ -48,6 +48,7 @@ type GatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper userMsgQueueHelper *UserMsgQueueHelper
requestEventBus *service.RequestEventBus
maxAccountSwitches int maxAccountSwitches int
maxAccountSwitchesGemini int maxAccountSwitchesGemini int
cfg *config.Config cfg *config.Config
@ -70,6 +71,7 @@ func NewGatewayHandler(
userMsgQueueService *service.UserMessageQueueService, userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config, cfg *config.Config,
settingService *service.SettingService, settingService *service.SettingService,
requestEventBus *service.RequestEventBus,
) *GatewayHandler { ) *GatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
maxAccountSwitches := 10 maxAccountSwitches := 10
@ -103,6 +105,7 @@ func NewGatewayHandler(
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper, userMsgQueueHelper: umqHelper,
requestEventBus: requestEventBus,
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini, maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg, cfg: cfg,
@ -113,6 +116,7 @@ func NewGatewayHandler(
// Messages handles Claude API compatible messages endpoint // Messages handles Claude API compatible messages endpoint
// POST /v1/messages // POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) { func (h *GatewayHandler) Messages(c *gin.Context) {
reqStartTime := time.Now()
// 从context获取apiKey和userApiKeyAuth中间件已设置 // 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok { if !ok {
@ -164,6 +168,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 解析渠道级模型映射 // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 实时请求查看器:记录每次请求的结果(账号、模型、状态、延迟)
var (
reqEventAccountID int64
reqEventStatus = "error"
)
defer func() {
if h.requestEventBus != nil {
h.requestEventBus.Publish(service.RequestEvent{
Timestamp: reqStartTime,
Method: c.Request.Method,
Path: c.FullPath(),
Model: reqModel,
AccountID: reqEventAccountID,
Status: reqEventStatus,
LatencyMS: time.Since(reqStartTime).Milliseconds(),
})
}
}()
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
@ -406,6 +429,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmCancel()
if rpmErr != nil {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
reqEventAccountID = account.ID
reqEventStatus = "rate_limited"
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
@ -473,6 +513,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 实时请求查看器:标记 Gemini 路径成功
reqEventAccountID = account.ID
reqEventStatus = "success"
// RPM 计数递增Forward 成功后) // RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 // 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
@ -650,6 +694,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} }
} }
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmCancel()
if rpmErr != nil {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
reqEventAccountID = account.ID
reqEventStatus = "rate_limited"
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
@ -850,6 +911,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 实时请求查看器:标记 Anthropic 路径成功
reqEventAccountID = account.ID
reqEventStatus = "success"
// RPM 计数递增Forward 成功后) // RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 // 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。

View File

@ -29,6 +29,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
"golang.org/x/sync/singleflight"
entsql "entgo.io/ent/dialect/sql" entsql "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/dialect/sql/sqljson"
@ -49,6 +50,13 @@ type accountRepository struct {
// Used to proactively sync account snapshot to cache when status changes, // Used to proactively sync account snapshot to cache when status changes,
// ensuring sticky sessions can promptly detect unavailable accounts. // ensuring sticky sessions can promptly detect unavailable accounts.
schedulerCache service.SchedulerCache schedulerCache service.SchedulerCache
// tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。
// 上游 401/限流爆发时N 个 in-flight 请求会同时调用此方法;底层 SQL
// 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row但 N 次
// SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight
// 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。
tempUnschedSF singleflight.Group
} }
var schedulerNeutralExtraKeyPrefixes = []string{ var schedulerNeutralExtraKeyPrefixes = []string{
@ -1124,6 +1132,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
} }
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
// 进程内合并并发调用key 包含 until 和 reason确保不同窗口/原因独立去重。
// until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致;
// 哪怕略有偏差SQL 的 (existing < new) 条件保证语义安全。
sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason
_, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) {
return nil, r.setTempUnschedulableOnce(ctx, id, until, reason)
})
return err
}
func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error {
_, err := r.sql.ExecContext(ctx, ` _, err := r.sql.ExecContext(ctx, `
UPDATE accounts UPDATE accounts
SET temp_unschedulable_until = $1, SET temp_unschedulable_until = $1,

View File

@ -0,0 +1,119 @@
package repository
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// blockingExecutor 是一个最小化的 sqlExecutor 实现,用于精确控制并发时序。
// ExecContext 会等待 release 信号才返回,便于让多个 goroutine 集中堆积在
// singleflight 的同一窗口内。
type blockingExecutor struct {
mu sync.Mutex
execCalls int32
queryCalls int32
release chan struct{}
concurrent int32
maxObserved int32
}
func newBlockingExecutor() *blockingExecutor {
return &blockingExecutor{release: make(chan struct{})}
}
func (e *blockingExecutor) Release() { close(e.release) }
func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
atomic.AddInt32(&e.execCalls, 1)
c := atomic.AddInt32(&e.concurrent, 1)
for {
old := atomic.LoadInt32(&e.maxObserved)
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
break
}
}
defer atomic.AddInt32(&e.concurrent, -1)
<-e.release
return driverResult{}, nil
}
func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
atomic.AddInt32(&e.queryCalls, 1)
return nil, sql.ErrNoRows
}
// driverResult 是一个零值 sql.Result用于测试。
type driverResult struct{}
func (driverResult) LastInsertId() (int64, error) { return 0, nil }
func (driverResult) RowsAffected() (int64, error) { return 1, nil }
func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) {
// 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际
// SQL 路径UPDATE + outbox INSERT = 2 次 ExecContext
exec := newBlockingExecutor()
repo := newAccountRepositoryWithSQL(nil, exec, nil)
const callers = 30
until := time.Now().Add(10 * time.Minute)
const reason = "OAuth 401: invalid_grant"
var wg sync.WaitGroup
wg.Add(callers)
for i := 0; i < callers; i++ {
go func() {
defer wg.Done()
_ = repo.SetTempUnschedulable(context.Background(), 42, until, reason)
}()
}
// 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
time.Sleep(5 * time.Millisecond)
}
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent),
"singleflight should serialize the SQL call to exactly one in-flight execution")
exec.Release()
wg.Wait()
// 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec其余 29 个 caller 共享结果。
require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2),
"expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls)
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved),
"no two SQL execs should run concurrently for the same singleflight key")
}
func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) {
// 不同 account 应分属不同 sf key能并行写库。
exec := newBlockingExecutor()
repo := newAccountRepositoryWithSQL(nil, exec, nil)
until := time.Now().Add(10 * time.Minute)
var wg sync.WaitGroup
for i := int64(1); i <= 3; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
_ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason")
}()
}
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
time.Sleep(5 * time.Millisecond)
}
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved),
"different accounts should be able to write in parallel")
exec.Release()
wg.Wait()
}

View File

@ -38,6 +38,7 @@ func ProvideRouter(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService, settingService *service.SettingService,
healthService *service.HealthService,
redisClient *redis.Client, redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
@ -95,7 +96,7 @@ func ProvideRouter(
service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
}) })
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
} }
// ProvideHTTPServer 提供 HTTP 服务器 // ProvideHTTPServer 提供 HTTP 服务器

View File

@ -30,6 +30,7 @@ func SetupRouter(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService, settingService *service.SettingService,
healthService *service.HealthService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client, redisClient *redis.Client,
) *gin.Engine { ) *gin.Engine {
@ -81,7 +82,7 @@ func SetupRouter(
} }
// 注册路由 // 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
return r return r
} }
@ -97,11 +98,12 @@ func registerRoutes(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
opsService *service.OpsService, opsService *service.OpsService,
settingService *service.SettingService, settingService *service.SettingService,
healthService *service.HealthService,
cfg *config.Config, cfg *config.Config,
redisClient *redis.Client, redisClient *redis.Client,
) { ) {
// 通用路由(健康检查、状态等) // 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r) routes.RegisterCommonRoutes(r, healthService)
// API v1 // API v1
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")

View File

@ -151,10 +151,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds) settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
} }
// WebSocket realtime (QPS/TPS) // WebSocket realtime (QPS/TPS and request stream)
ws := ops.Group("/ws") ws := ops.Group("/ws")
{ {
ws.GET("/qps", h.Admin.Ops.QPSWSHandler) ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
ws.GET("/requests", h.Admin.Ops.RequestStreamWSHandler)
} }
// Error logs (legacy) // Error logs (legacy)

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -16,11 +17,37 @@ const (
claudeCodeGrowthBookDateUpdated = "1970-01-01T00:00:00Z" claudeCodeGrowthBookDateUpdated = "1970-01-01T00:00:00Z"
) )
// RegisterCommonRoutes 注册通用路由(健康检查、状态等) // readinessHandlerTimeout 限定 readiness 端点对外的最大返回耗时。
func RegisterCommonRoutes(r *gin.Engine) { // HealthService 内部对每个组件再有独立超时,所以这里给宽一点即可。
// 健康检查 const readinessHandlerTimeout = 3 * time.Second
r.GET("/health", func(c *gin.Context) {
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)。
//
// 健康端点的语义分层:
// - /healthz : liveness 探针。零依赖、永远 200。容器/进程探活专用。
// - /ready : readiness 探针。检查 DB+Redis任一失败返回 503。
// - /health : 历史端点,等价于 /healthz保留向后兼容。
//
// dashboard 用的"业务健康分"由 ops_health_score 单独提供,与本路由无关。
func RegisterCommonRoutes(r *gin.Engine, healthService *service.HealthService) {
// Liveness仅证明进程在响应。
livenessHandler := func(c *gin.Context) {
_ = healthService.Liveness()
c.JSON(http.StatusOK, gin.H{"status": "ok"}) c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
r.GET("/healthz", livenessHandler)
r.GET("/health", livenessHandler) // 向后兼容旧的 docker-compose healthcheck
// Readiness检查关键依赖。失败时返回 503 但仍带详情,便于排障。
r.GET("/ready", func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c.Request.Context(), readinessHandlerTimeout)
defer cancel()
report := healthService.Readiness(ctx)
status := http.StatusOK
if !report.OK {
status = http.StatusServiceUnavailable
}
c.JSON(status, report)
}) })
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。 // Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。

View File

@ -7,16 +7,56 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
) )
func newCommonRoutesTestRouter() *gin.Engine { func newCommonRoutesTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
RegisterCommonRoutes(r) RegisterCommonRoutes(r, service.NewHealthService(nil, nil))
return r return r
} }
func newTestRouter(t *testing.T, hs *service.HealthService) *gin.Engine {
t.Helper()
gin.SetMode(gin.TestMode)
r := gin.New()
RegisterCommonRoutes(r, hs)
return r
}
func TestCommonRoutes_LivenessEndpoints(t *testing.T) {
r := newTestRouter(t, service.NewHealthService(nil, nil))
for _, path := range []string{"/healthz", "/health"} {
req := httptest.NewRequest(http.MethodGet, path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "liveness path %s should be 200", path)
}
}
func TestCommonRoutes_ReadyEndpoint_NoDepsReturnsOK(t *testing.T) {
// 没有 DB/Redis 依赖时 readiness 视为 ok早期启动场景
r := newTestRouter(t, service.NewHealthService(nil, nil))
req := httptest.NewRequest(http.MethodGet, "/ready", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Contains(t, w.Body.String(), "\"ok\":true")
}
func TestCommonRoutes_SetupStatusUnchanged(t *testing.T) {
// 验证我们没有破坏既有的 /setup/status 行为(前端依赖)。
r := newTestRouter(t, service.NewHealthService(nil, nil))
req := httptest.NewRequest(http.MethodGet, "/setup/status", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Contains(t, w.Body.String(), "needs_setup")
}
func TestCommonRoutes_ClaudeCodeBootstrap(t *testing.T) { func TestCommonRoutes_ClaudeCodeBootstrap(t *testing.T) {
r := newCommonRoutesTestRouter() r := newCommonRoutesTestRouter()

View File

@ -988,6 +988,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false return false
} }
// IsContextCompressionEnabled returns true if the account has opted into proactive
// context compression. When enabled, the gateway will trim oldest messages before
// dispatch to keep the estimated token count within the configured budget.
func (a *Account) IsContextCompressionEnabled() bool {
if a.Credentials == nil {
return false
}
enabled, _ := a.Credentials["enable_context_compression"].(bool)
return enabled
}
func (a *Account) IsBedrock() bool { func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
} }
@ -1572,6 +1583,24 @@ func (a *Account) GetQuotaUsed() float64 {
return a.getExtraFloat64("quota_used") return a.getExtraFloat64("quota_used")
} }
// GetQuotaRemainingFraction returns the fraction of total quota remaining in [0,1].
// Returns 1.0 when no quota limit is set (limit == 0 means unlimited).
func (a *Account) GetQuotaRemainingFraction() float64 {
limit := a.GetQuotaLimit()
if limit <= 0 {
return 1.0
}
used := a.GetQuotaUsed()
remaining := (limit - used) / limit
if remaining < 0 {
return 0
}
if remaining > 1 {
return 1
}
return remaining
}
// GetQuotaDailyLimit 获取日额度限制美元0 表示未启用 // GetQuotaDailyLimit 获取日额度限制美元0 表示未启用
func (a *Account) GetQuotaDailyLimit() float64 { func (a *Account) GetQuotaDailyLimit() float64 {
return a.getExtraFloat64("quota_daily_limit") return a.getExtraFloat64("quota_daily_limit")

View File

@ -40,6 +40,7 @@ func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
want bool want bool
}{ }{
{"标准版本号", "claude-cli/1.0.0", true}, {"标准版本号", "claude-cli/1.0.0", true},
{"官方 transport UA", "claude-code/2.1.88", true},
{"多位版本号", "claude-cli/12.34.56", true}, {"多位版本号", "claude-cli/12.34.56", true},
{"大写开头", "Claude-CLI/1.0.0", true}, {"大写开头", "Claude-CLI/1.0.0", true},
{"非 claude-cli", "curl/7.64.1", false}, {"非 claude-cli", "curl/7.64.1", false},
@ -90,6 +91,19 @@ func TestValidate_MessagesPath_FullValid(t *testing.T) {
require.True(t, result, "完整有效请求应通过") require.True(t, result, "完整有效请求应通过")
} }
func TestValidate_MessagesPath_FullValid_ClaudeCodeUA(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-code/2.1.88")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
req.Header.Set("anthropic-version", "2023-06-01")
result := v.Validate(req, validClaudeCodeBody())
require.True(t, result, "官方 transport/helper UA 也应通过")
}
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
v := newTestValidator() v := newTestValidator()
body := validClaudeCodeBody() body := validClaudeCodeBody()

View File

@ -15,11 +15,13 @@ import (
type ClaudeCodeValidator struct{} type ClaudeCodeValidator struct{}
var ( var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI大小写不敏感) // User-Agent 匹配: 官方 Claude Code 目前存在两类产品前缀:
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) // 1. 主 Anthropic API 客户端: claude-cli/x.y.z (...)
// 2. transport / helper 请求: claude-code/x.y.z
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/\d+\.\d+\.\d+`)
// 带捕获组的版本提取正则 // 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/(\d+\.\d+\.\d+)`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5 systemPromptThreshold = 0.5
@ -55,7 +57,7 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
// Validate 验证请求是否来自 Claude Code CLI // Validate 验证请求是否来自 Claude Code CLI
// 采用与 claude-relay-service 完全一致的验证策略: // 采用与 claude-relay-service 完全一致的验证策略:
// //
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x // Step 1: User-Agent 检查 (必需) - 必须是官方 claude-cli/ 或 claude-code/ 前缀
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过 // Step 2: 对于非 messages 路径,只要 UA 匹配就通过
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过UA 已验证) // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过UA 已验证)
// Step 4: 对于 messages 路径,进行严格验证: // Step 4: 对于 messages 路径,进行严格验证:

View File

@ -64,6 +64,7 @@ func TestExtractVersion(t *testing.T) {
want string want string
}{ }{
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"}, {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
{"claude-code/2.1.88", "2.1.88"},
{"claude-cli/1.0.0", "1.0.0"}, {"claude-cli/1.0.0", "1.0.0"},
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感 {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
{"curl/8.0.0", ""}, // 非 Claude CLI {"curl/8.0.0", ""}, // 非 Claude CLI

View File

@ -0,0 +1,151 @@
package service
import (
"encoding/json"
"math"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
const defaultContextCompressionMaxTokens = 190_000
// approxTokens estimates the token count for a string using the chars/4 heuristic.
func approxTokens(s string) int {
return int(math.Ceil(float64(len(s)) / 4.0))
}
// compressMessagesInBody trims the oldest messages from the request body so that the
// estimated token count of the messages array fits within maxTokens.
// Returns the original body unchanged if no compression is needed or if parsing fails.
func compressMessagesInBody(body []byte, maxTokens int) []byte {
msgsResult := gjson.GetBytes(body, "messages")
if !msgsResult.Exists() || !msgsResult.IsArray() {
return body
}
// Unmarshal to a typed slice for processing.
var messages []map[string]any
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
return body
}
compressed, changed := compressMessages(messages, maxTokens)
if !changed {
return body
}
newMsgs, err := json.Marshal(compressed)
if err != nil {
return body
}
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
if err != nil {
return body
}
return updated
}
// compressMessages removes the oldest messages from the front of msgs until the
// estimated total token count is at or below maxTokens.
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
// to avoid orphaned tool_result blocks.
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
if len(msgs) == 0 {
return msgs, false
}
// Estimate total tokens.
totalTokens := 0
for _, m := range msgs {
totalTokens += msgTokens(m)
}
if totalTokens <= maxTokens {
return msgs, false
}
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
type unit struct {
startIdx int
endIdx int // exclusive
tokens int
}
units := make([]unit, 0, len(msgs))
i := 0
for i < len(msgs) {
toks := msgTokens(msgs[i])
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
toks += msgTokens(msgs[i+1])
units = append(units, unit{i, i + 2, toks})
i += 2
} else {
units = append(units, unit{i, i + 1, toks})
i++
}
}
// Remove units from the front until we are within budget.
// Always keep at least the last unit so we never send an empty messages array.
removeCount := 0
for removeCount < len(units)-1 && totalTokens > maxTokens {
totalTokens -= units[removeCount].tokens
removeCount++
}
if removeCount == 0 {
return msgs, false
}
cutIdx := units[removeCount].startIdx
return msgs[cutIdx:], true
}
// msgTokens estimates token count for a single message using the chars/4 heuristic.
func msgTokens(msg map[string]any) int {
b, err := json.Marshal(msg)
if err != nil {
return 0
}
return approxTokens(string(b))
}
// isAssistantWithToolUse returns true if msg is an assistant message whose content
// contains at least one block with "type": "tool_use".
func isAssistantWithToolUse(msg map[string]any) bool {
role, _ := msg["role"].(string)
if role != "assistant" {
return false
}
return contentContainsType(msg["content"], "tool_use")
}
// isUserWithToolResult returns true if msg is a user message whose content
// contains at least one block with "type": "tool_result".
func isUserWithToolResult(msg map[string]any) bool {
role, _ := msg["role"].(string)
if role != "user" {
return false
}
return contentContainsType(msg["content"], "tool_result")
}
// contentContainsType returns true if content (a []any of blocks) contains a block
// whose "type" field equals blockType.
func contentContainsType(content any, blockType string) bool {
blocks, ok := content.([]any)
if !ok {
return false
}
for _, b := range blocks {
block, ok := b.(map[string]any)
if !ok {
continue
}
if t, _ := block["type"].(string); t == blockType {
return true
}
}
return false
}

View File

@ -0,0 +1,195 @@
package service
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// helpers
func makeMsg(role, text string) map[string]any {
return map[string]any{
"role": role,
"content": text,
}
}
func makeToolUseMsg(id string) map[string]any {
return map[string]any{
"role": "assistant",
"content": []any{
map[string]any{
"type": "tool_use",
"id": id,
"name": "search",
"input": map[string]any{},
},
},
}
}
func makeToolResultMsg(toolUseID string) map[string]any {
return map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "tool_result",
"tool_use_id": toolUseID,
"content": "result text",
},
},
}
}
func toAnySlice(msgs []map[string]any) []any {
out := make([]any, len(msgs))
for i, m := range msgs {
out[i] = m
}
return out
}
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
t.Helper()
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
require.NoError(t, err)
return b
}
// tests
func TestApproxTokens(t *testing.T) {
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
assert.Equal(t, 0, approxTokens(""))
}
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
msgs := []map[string]any{
makeMsg("user", "hi"),
makeMsg("assistant", "hello"),
}
result, changed := compressMessages(msgs, 100_000)
assert.False(t, changed)
assert.Len(t, result, 2)
}
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
// 10 messages, each large enough to be over a tight budget when combined.
msgs := make([]map[string]any, 10)
for i := range msgs {
role := "user"
if i%2 == 1 {
role = "assistant"
}
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
}
// Force compression by using a very small token budget.
result, changed := compressMessages(msgs, 1)
assert.True(t, changed)
// Must keep at least one message (the last).
assert.GreaterOrEqual(t, len(result), 1)
// The remaining messages should be from the tail (newest).
lastOrig := msgs[len(msgs)-1]["content"]
lastResult := result[len(result)-1]["content"]
assert.Equal(t, lastOrig, lastResult)
}
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
// Messages: user → assistant+tool_use → user+tool_result → assistant
msgs := []map[string]any{
makeMsg("user", "start"),
makeToolUseMsg("tool-1"),
makeToolResultMsg("tool-1"),
makeMsg("assistant", "done"),
}
// Budget that forces removal of the first non-paired message but keeps the tool pair.
// Estimate total tokens and set budget to force removing only "start" but not the pair.
total := 0
for _, m := range msgs {
total += msgTokens(m)
}
// Budget: remove "start" but keep tool pair + "done".
startTokens := msgTokens(msgs[0])
budget := total - startTokens
result, changed := compressMessages(msgs, budget)
assert.True(t, changed)
// tool_use and tool_result should both be present or both absent.
hasToolUse := false
hasToolResult := false
for _, m := range result {
if isAssistantWithToolUse(m) {
hasToolUse = true
}
if isUserWithToolResult(m) {
hasToolResult = true
}
}
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
}
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
// Budget forces removal of the tool pair.
msgs := []map[string]any{
makeMsg("user", "start"),
makeToolUseMsg("tool-1"),
makeToolResultMsg("tool-1"),
makeMsg("assistant", "final answer after tool use"),
}
// Budget: only keep the last "assistant" message.
lastTokens := msgTokens(msgs[len(msgs)-1])
result, changed := compressMessages(msgs, lastTokens)
assert.True(t, changed)
// Neither tool_use nor tool_result should remain.
for _, m := range result {
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
}
}
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
result := compressMessagesInBody(body, 1)
assert.Equal(t, body, result, "body without messages should be unchanged")
}
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
msgs := []map[string]any{makeMsg("user", "hi")}
body := bodyWithMessages(t, msgs)
result := compressMessagesInBody(body, 100_000)
assert.Equal(t, body, result, "body under budget should be unchanged")
}
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
msgs := make([]map[string]any, 20)
for i := range msgs {
role := "user"
if i%2 == 1 {
role = "assistant"
}
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
}
body := bodyWithMessages(t, msgs)
// Force significant compression.
result := compressMessagesInBody(body, 50)
assert.Less(t, len(result), len(body), "compressed body should be smaller")
// Resulting body should still be valid JSON with a messages array.
var parsed map[string]any
require.NoError(t, json.Unmarshal(result, &parsed))
resultMsgs, ok := parsed["messages"].([]any)
require.True(t, ok)
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
}

View File

@ -689,6 +689,41 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
} }
func TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
sessionID := "12345678-1234-1234-1234-123456789abc"
body, err := json.Marshal(map[string]any{
"model": "claude-3-7-sonnet-20250219",
"metadata": map[string]any{
"user_id": FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
sessionID,
claude.DefaultCLIProductVersion,
),
},
})
require.NoError(t, err)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
},
}
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
}
req, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-3-7-sonnet-20250219", false, false)
require.NoError(t, err)
require.Equal(t, sessionID, getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"))
}
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) { func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil, nil,
nil, nil,
nil, nil,
nil,
) )
} }

View File

@ -335,8 +335,9 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@ -569,7 +570,8 @@ type GatewayService struct {
concurrencyService *ConcurrencyService concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
userGroupRateResolver *userGroupRateResolver userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group userGroupRateSF singleflight.Group
@ -614,6 +616,7 @@ func NewGatewayService(
channelService *ChannelService, channelService *ChannelService,
resolver *ModelPricingResolver, resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService, balanceNotifyService *BalanceNotifyService,
rpmTokenBucketSvc *RPMTokenBucketService,
) *GatewayService { ) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg)
@ -640,6 +643,7 @@ func NewGatewayService(
claudeTokenProvider: claudeTokenProvider, claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache, sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache, rpmCache: rpmCache,
rpmTokenBucket: rpmTokenBucketSvc,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService, settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCache: gocache.New(modelsListTTL, time.Minute),
@ -1374,6 +1378,9 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
if platform == PlatformAnthropic && s.enableTierFallbackChain() {
return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs)
}
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil { if err != nil {
return nil, err return nil, err
@ -2558,6 +2565,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
return err return err
} }
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
if s.rpmTokenBucket == nil {
return nil
}
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制 // checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号 // 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash // sessionID: 会话标识符(使用粘性会话的 hash
@ -3765,6 +3781,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return ParseMetadataUserID(metadataUserID) != nil return ParseMetadataUserID(metadataUserID) != nil
} }
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
if IsClaudeCodeClient(ctx) {
return true
}
if parsed == nil || c == nil {
return false
}
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型string / []any / nil // normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型string / []any / nil
// 避免 type switch 中 json.RawMessage底层 []byte无法匹配 case string / case []any / case nil 的问题。 // 避免 type switch 中 json.RawMessage底层 []byte无法匹配 case string / case []any / case nil 的问题。
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 // 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
@ -4323,6 +4349,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body) body = StripEmptyTextBlocks(body)
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
if account.IsContextCompressionEnabled() {
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
body = compressMessagesInBody(body, maxTok)
}
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body) setOpsUpstreamRequestBody(c, body)
@ -5923,15 +5955,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
// X-Claude-Code-Session-Id 头处理: // X-Claude-Code-Session-Id 头处理:
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id // Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
// 2. 客户端未提供mimic 模式)→ 生成确定性 per-account session UUID // 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
// 真实 CLI 每个请求都携带此 headerper-process UUID if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 if parsed := ParseMetadataUserID(uid); parsed != nil {
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
} }
} }
@ -8969,12 +8997,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 // Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { // 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil { if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
} }
} }

View File

@ -0,0 +1,133 @@
package service
import (
"context"
"errors"
)
// accountTierLevel maps an account type to a scheduling tier:
//
// 0 = subscription (OAuth / SetupToken) — tried first
// 1 = API Key — first fallback
// 2 = Bedrock — last resort
//
// Accounts with an unknown type fall into tier 0 so they participate in the
// primary selection and do not vanish silently.
func accountTierLevel(account *Account) int {
if account == nil {
return 0
}
switch account.Type {
case AccountTypeAPIKey:
return 1
case AccountTypeBedrock:
return 2
default: // OAuth, SetupToken, or unknown
return 0
}
}
// enableTierFallbackChain reports whether the cross-tier fallback chain is
// enabled in config (default false).
func (s *GatewayService) enableTierFallbackChain() bool {
return s != nil && s.cfg != nil && s.cfg.Gateway.Scheduling.EnableTierFallbackChain
}
// selectAccountWithTierFallback tries Anthropic accounts in tier order:
// tier 0 (OAuth/SetupToken subscription) → tier 1 (API Key) → tier 2 (Bedrock).
//
// Sticky sessions are honored within the chain: if the session-bound account is
// in a tier that still has capacity it is returned immediately; otherwise the
// session binding is cleared and the chain proceeds from tier 0.
func (s *GatewayService) selectAccountWithTierFallback(
ctx context.Context,
groupID *int64,
sessionHash string,
requestedModel string,
excludedIDs map[int64]struct{},
) (*Account, error) {
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAnthropic, false)
if err != nil {
return nil, err
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// Build per-tier candidate lists (pointers into `accounts`).
const numTiers = 3
tierCandidates := [numTiers][]*Account{}
for i := range accounts {
acc := &accounts[i]
if acc.Platform != PlatformAnthropic {
continue
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if !s.isAccountSchedulableForSelection(acc) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForQuota(acc) {
continue
}
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
tier := accountTierLevel(acc)
if tier < numTiers {
tierCandidates[tier] = append(tierCandidates[tier], acc)
}
}
cfg := s.schedulingConfig()
selectionMode := cfg.FallbackSelectionMode
// Check sticky session: if the bound account is a valid candidate, use it.
if sessionHash != "" && s.cache != nil {
accountID, cacheErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if cacheErr == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
for tier := 0; tier < numTiers; tier++ {
for _, acc := range tierCandidates[tier] {
if acc.ID != accountID {
continue
}
if shouldClearStickySession(acc, requestedModel) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
break
}
if s.isAccountSchedulableForWindowCost(ctx, acc, true) &&
s.isAccountSchedulableForRPM(ctx, acc, true) {
return acc, nil
}
}
}
}
}
}
// Try each tier in order.
for tier := 0; tier < numTiers; tier++ {
candidates := tierCandidates[tier]
if len(candidates) == 0 {
continue
}
s.sortCandidatesForFallback(candidates, false, selectionMode)
result, acquired, _ := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, false)
if acquired && result != nil {
return result.Account, nil
}
}
return nil, errors.New("no available accounts in any tier")
}

View File

@ -0,0 +1,138 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestAccountTierLevel(t *testing.T) {
require.Equal(t, 0, accountTierLevel(nil))
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeOAuth}))
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeSetupToken}))
require.Equal(t, 0, accountTierLevel(&Account{Type: "unknown"}))
require.Equal(t, 1, accountTierLevel(&Account{Type: AccountTypeAPIKey}))
require.Equal(t, 2, accountTierLevel(&Account{Type: AccountTypeBedrock}))
}
func TestGatewayService_EnableTierFallbackChain(t *testing.T) {
require.False(t, (*GatewayService)(nil).enableTierFallbackChain())
require.False(t, (&GatewayService{}).enableTierFallbackChain())
cfgOff := &config.Config{}
cfgOff.Gateway.Scheduling.EnableTierFallbackChain = false
require.False(t, (&GatewayService{cfg: cfgOff}).enableTierFallbackChain())
cfgOn := &config.Config{}
cfgOn.Gateway.Scheduling.EnableTierFallbackChain = true
require.True(t, (&GatewayService{cfg: cfgOn}).enableTierFallbackChain())
}
// TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription verifies
// that when both OAuth (subscription) and APIKey accounts are available, the
// tier-0 OAuth account is always selected first even if APIKey has higher priority.
func TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription(t *testing.T) {
ctx := context.Background()
oauthAcc := Account{ID: 91001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Priority: 5}
apiKeyAcc := Account{ID: 91002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Priority: 0}
repo := &mockAccountRepoForPlatform{
accounts: []Account{oauthAcc, apiKeyAcc},
accountsByID: map[int64]*Account{91001: &oauthAcc, 91002: &apiKeyAcc},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(91001), acc.ID, "OAuth (tier-0) account should be preferred over APIKey (tier-1)")
}
// TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey verifies
// that when the subscription tier has no schedulable accounts, the fallback
// selects an API Key account.
func TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey(t *testing.T) {
ctx := context.Background()
rateLimitedUntil := time.Now().Add(30 * time.Minute)
oauthAcc := Account{ID: 92001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
apiKeyAcc := Account{ID: 92002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
repo := &mockAccountRepoForPlatform{
accounts: []Account{oauthAcc, apiKeyAcc},
accountsByID: map[int64]*Account{92001: &oauthAcc, 92002: &apiKeyAcc},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(92002), acc.ID, "Should fall back to APIKey when OAuth is rate-limited")
}
// TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts ensures
// excluded IDs are respected across all tiers.
func TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts(t *testing.T) {
ctx := context.Background()
oauthAcc := Account{ID: 93001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}
apiKeyAcc := Account{ID: 93002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
repo := &mockAccountRepoForPlatform{
accounts: []Account{oauthAcc, apiKeyAcc},
accountsByID: map[int64]*Account{93001: &oauthAcc, 93002: &apiKeyAcc},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
excluded := map[int64]struct{}{93001: {}}
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", excluded)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(93002), acc.ID, "Excluded OAuth account should cause APIKey fallback")
}
// TestGatewayService_SelectAccountWithTierFallback_NoAccounts verifies that
// an error is returned when all tiers are empty.
func TestGatewayService_SelectAccountWithTierFallback_NoAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{accounts: nil, accountsByID: map[int64]*Account{}}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
require.Error(t, err)
require.Nil(t, acc)
}
// TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort verifies
// that Bedrock accounts are only used when subscription and API Key tiers are exhausted.
func TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort(t *testing.T) {
ctx := context.Background()
rateLimitedUntil := time.Now().Add(30 * time.Minute)
oauthAcc := Account{ID: 94001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
apiKeyAcc := Account{ID: 94002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
bedrockAcc := Account{ID: 94003, Platform: PlatformAnthropic, Type: AccountTypeBedrock, Status: StatusActive, Schedulable: true}
repo := &mockAccountRepoForPlatform{
accounts: []Account{oauthAcc, apiKeyAcc, bedrockAcc},
accountsByID: map[int64]*Account{94001: &oauthAcc, 94002: &apiKeyAcc, 94003: &bedrockAcc},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(94003), acc.ID, "Bedrock should be selected as last resort")
}

View File

@ -0,0 +1,119 @@
// Package service - HealthService 提供 liveness 与 readiness 探针。
//
// 设计动机:原有 /health 端点既被 docker-compose healthcheck 使用,又被
// dashboard 的 ops_health_score 复用——后者会触发 DB/Redis 等重操作,
// 导致探活流量污染监控指标。本服务把两类语义拆开:
// - Liveness : 仅证明进程存活(无外部依赖检查)。
// - Readiness : 检查 DB + Redis 连通,作为是否可接收流量的判断。
//
// dashboard 维度的"业务健康分"仍由 ops_health_score 计算,与本服务无关。
package service
import (
"context"
"database/sql"
"errors"
"time"
"github.com/redis/go-redis/v9"
)
// 探针默认超时。Readiness 探针需要快速失败,避免堆积。
const (
defaultReadinessTimeout = 2 * time.Second
)
// ReadinessReport 描述各依赖项的状态,便于上层暴露细节给排障。
type ReadinessReport struct {
OK bool `json:"ok"`
Details map[string]ComponentStatus `json:"details"`
Elapsed time.Duration `json:"elapsed_ms"`
}
// ComponentStatus 单个依赖项的状态。Error 字段在 OK=true 时为空。
type ComponentStatus struct {
OK bool `json:"ok"`
Error string `json:"error,omitempty"`
Elapsed string `json:"elapsed,omitempty"`
}
// HealthService 提供 liveness/readiness 探针。
// 字段都允许为 nil缺失的依赖在 readiness 中自动跳过,便于测试和分阶段启用。
type HealthService struct {
db *sql.DB
rdb *redis.Client
timeout time.Duration
}
// NewHealthService 构造函数。timeout<=0 时使用默认值。
func NewHealthService(db *sql.DB, rdb *redis.Client) *HealthService {
return &HealthService{
db: db,
rdb: rdb,
timeout: defaultReadinessTimeout,
}
}
// Liveness 仅返回 nil。任何调用方能拿到这个返回值就说明进程在响应请求。
// 保持无副作用、零依赖,便于 K8s livenessProbe 高频调用。
func (s *HealthService) Liveness() error {
return nil
}
// Readiness 检查所有外部依赖。任一失败则整体 OK=false。
// 单个依赖的 ctx 超时由 timeout 控制,独立计时不互相阻塞。
func (s *HealthService) Readiness(ctx context.Context) ReadinessReport {
start := time.Now()
report := ReadinessReport{
OK: true,
Details: make(map[string]ComponentStatus, 2),
}
if s.db != nil {
report.Details["database"] = s.checkDB(ctx)
if !report.Details["database"].OK {
report.OK = false
}
}
if s.rdb != nil {
report.Details["redis"] = s.checkRedis(ctx)
if !report.Details["redis"].OK {
report.OK = false
}
}
report.Elapsed = time.Since(start)
return report
}
func (s *HealthService) checkDB(parent context.Context) ComponentStatus {
ctx, cancel := context.WithTimeout(parent, s.timeout)
defer cancel()
start := time.Now()
err := s.db.PingContext(ctx)
status := ComponentStatus{Elapsed: time.Since(start).String()}
if err != nil {
status.Error = err.Error()
return status
}
status.OK = true
return status
}
func (s *HealthService) checkRedis(parent context.Context) ComponentStatus {
ctx, cancel := context.WithTimeout(parent, s.timeout)
defer cancel()
start := time.Now()
pong, err := s.rdb.Ping(ctx).Result()
status := ComponentStatus{Elapsed: time.Since(start).String()}
if err != nil {
status.Error = err.Error()
return status
}
if pong != "PONG" {
status.Error = errors.New("unexpected redis ping response: " + pong).Error()
return status
}
status.OK = true
return status
}

View File

@ -0,0 +1,93 @@
package service
import (
"context"
"database/sql"
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestHealthService_Liveness_AlwaysOK(t *testing.T) {
s := NewHealthService(nil, nil)
require.NoError(t, s.Liveness())
}
func TestHealthService_Readiness_AllNilReturnsOK(t *testing.T) {
// 当所有依赖都为 nil 时(早期启动或 unit testreadiness 应直接 OK。
s := NewHealthService(nil, nil)
report := s.Readiness(context.Background())
require.True(t, report.OK)
require.Empty(t, report.Details)
}
func TestHealthService_Readiness_DBPingFails(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()
mock.ExpectPing().WillReturnError(errors.New("connection refused"))
s := NewHealthService(db, nil)
report := s.Readiness(context.Background())
require.False(t, report.OK)
require.Contains(t, report.Details, "database")
require.False(t, report.Details["database"].OK)
require.Contains(t, report.Details["database"].Error, "connection refused")
}
func TestHealthService_Readiness_DBOK(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()
mock.ExpectPing()
s := NewHealthService(db, nil)
report := s.Readiness(context.Background())
require.True(t, report.OK)
require.True(t, report.Details["database"].OK)
}
func TestHealthService_Readiness_RedisFails(t *testing.T) {
// 指向一个不可达端口让 redis ping 立刻失败。
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 200 * time.Millisecond,
ReadTimeout: 200 * time.Millisecond,
})
defer rdb.Close()
s := NewHealthService(nil, rdb)
s.timeout = 500 * time.Millisecond
report := s.Readiness(context.Background())
require.False(t, report.OK)
require.Contains(t, report.Details, "redis")
require.False(t, report.Details["redis"].OK)
}
func TestHealthService_Readiness_PerComponentTimeout(t *testing.T) {
// 验证 readiness 在超时时不会无限挂住。
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()
mock.ExpectPing().WillDelayFor(2 * time.Second)
s := NewHealthService(db, nil)
s.timeout = 100 * time.Millisecond
start := time.Now()
report := s.Readiness(context.Background())
elapsed := time.Since(start)
require.Less(t, elapsed, 1*time.Second, "readiness should respect per-component timeout")
require.False(t, report.OK)
require.NotEmpty(t, report.Details["database"].Error, "timeout should propagate as an error")
}
// 抑制未使用包警告database/sql 在签名里使用)。
var _ = sql.ErrNoRows

View File

@ -8,6 +8,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"golang.org/x/sync/singleflight"
) )
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 // OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
@ -30,12 +32,22 @@ type OAuthRefreshResult struct {
} }
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口 // OAuthRefreshAPI 统一的 OAuth Token 刷新入口
// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑 // 封装分布式锁、进程内去重singleflight、DB 重读、已刷新检查、竞争恢复等通用逻辑
//
// 双层去重设计:
// 1. 进程内 singleflight合并同一 cacheKey 的并发调用(避免 100 个 goroutine
// 都去抢同一把分布式锁、都重读一次 DB
// 2. 跨进程分布式锁Redis保证集群范围内只有一个 worker 真正发起 OAuth
// 刷新请求。
//
// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT跨进程锁仍是必需的
// singleflight 解决不了多 pod 同时刷新。
type OAuthRefreshAPI struct { type OAuthRefreshAPI struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenCache GeminiTokenCache // 可选nil = 无分布式锁 tokenCache GeminiTokenCache // 可选nil = 无分布式锁
lockTTL time.Duration lockTTL time.Duration
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
sf singleflight.Group
} }
// NewOAuthRefreshAPI 创建统一刷新 API // NewOAuthRefreshAPI 创建统一刷新 API
@ -63,15 +75,19 @@ func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
return mu return mu
} }
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。
//
// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者"
// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。
// //
// 流程: // 流程:
// 1. 获取分布式锁 // 1. singleflight 合并同 cacheKey 并发调用
// 2. 从 DB 重读最新 account防止使用过时的 refresh_token // 2. 获取分布式锁(跨进程)
// 3. 二次检查是否仍需刷新 // 3. 从 DB 重读最新 account防止使用过时的 refresh_token
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑 // 4. 二次检查是否仍需刷新
// 5. 设置 _token_version + 更新 DB // 5. 调用 executor.Refresh() 执行平台特定刷新逻辑
// 6. 释放锁 // 6. 设置 _token_version + 更新 DB
// 7. 释放锁
func (api *OAuthRefreshAPI) RefreshIfNeeded( func (api *OAuthRefreshAPI) RefreshIfNeeded(
ctx context.Context, ctx context.Context,
account *Account, account *Account,
@ -80,11 +96,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
) (*OAuthRefreshResult, error) { ) (*OAuthRefreshResult, error) {
cacheKey := executor.CacheKey(account) cacheKey := executor.CacheKey(account)
// 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争) // singleflight key 同时区分 cacheKey 和 refreshWindow
localMu := api.getLocalLock(cacheKey) // 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh
localMu.Lock() // 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。
defer localMu.Unlock() sfKey := cacheKey + "|" + refreshWindow.String()
v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) {
return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey)
})
if err != nil {
return nil, err
}
result, _ := v.(*OAuthRefreshResult)
return result, nil
}
// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。
// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。
func (api *OAuthRefreshAPI) refreshOnce(
ctx context.Context,
account *Account,
executor OAuthRefreshExecutor,
refreshWindow time.Duration,
cacheKey string,
) (*OAuthRefreshResult, error) {
// 1. 获取分布式锁 // 1. 获取分布式锁
lockAcquired := false lockAcquired := false
if api.tokenCache != nil { if api.tokenCache != nil {

View File

@ -0,0 +1,160 @@
//go:build unit
package service
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。
type blockingExecutor struct {
refreshAPIExecutorStub
release chan struct{}
concurrent int32 // 当前正在 Refresh 的 goroutine 数
maxObserved int32 // 观察到的最大并发数
calls int32
}
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
atomic.AddInt32(&e.calls, 1)
c := atomic.AddInt32(&e.concurrent, 1)
for {
old := atomic.LoadInt32(&e.maxObserved)
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
break
}
}
defer atomic.AddInt32(&e.concurrent, -1)
<-e.release
return e.credentials, e.err
}
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
// 同一 cacheKey 同时进入 N 个 goroutine应只触发 1 次 executor.Refresh。
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
const callers = 20
results := make([]*OAuthRefreshResult, callers)
errs := make([]error, callers)
var wg sync.WaitGroup
wg.Add(callers)
for i := 0; i < callers; i++ {
i := i
go func() {
defer wg.Done()
r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
results[i] = r
errs[i] = err
}()
}
// 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")
close(exec.release)
wg.Wait()
require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")
// 所有 caller 应拿到等价结果不必同实例singleflight Shared 标志会让多个 caller 共享)。
for i := 0; i < callers; i++ {
require.NoError(t, errs[i])
require.NotNil(t, results[i])
require.True(t, results[i].Refreshed)
}
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
// 不同账号有不同 cacheKey应能并行刷新而非互相阻塞。
repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
platform := "p" + string(rune('a'+i))
wg.Add(1)
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
}()
}
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")
close(exec.release)
wg.Wait()
}
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
// 同 cacheKey 但不同 refreshWindow前台短窗口 vs 后台长窗口)应分开判断
// NeedsRefresh避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
cache := &refreshAPICacheStub{lockResult: true}
exec := &blockingExecutor{
refreshAPIExecutorStub: refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new"},
},
release: make(chan struct{}),
}
api := NewOAuthRefreshAPI(repo, cache)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
}()
go func() {
defer wg.Done()
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour)
}()
deadline := time.Now().Add(2 * time.Second)
for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")
close(exec.release)
wg.Wait()
}

View File

@ -730,12 +730,18 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
} }
quotaFactor := item.account.GetQuotaRemainingFraction()
item.score = weights.Priority*priorityFactor + item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor + weights.Load*loadFactor +
weights.Queue*queueFactor + weights.Queue*queueFactor +
weights.ErrorRate*errorFactor + weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor weights.TTFT*ttftFactor +
weights.Quota*quotaFactor
}
if s.service.openAIWSP2CEnabled() {
return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew)
} }
} }
@ -1193,6 +1199,7 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
Quota: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota,
} }
} }
return GatewayOpenAIWSSchedulerScoreWeightsView{ return GatewayOpenAIWSSchedulerScoreWeightsView{
@ -1201,15 +1208,21 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
Queue: 0.7, Queue: 0.7,
ErrorRate: 0.8, ErrorRate: 0.8,
TTFT: 0.5, TTFT: 0.5,
Quota: 0.0,
} }
} }
func (s *OpenAIGatewayService) openAIWSP2CEnabled() bool {
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EnableP2CScheduling
}
type GatewayOpenAIWSSchedulerScoreWeightsView struct { type GatewayOpenAIWSSchedulerScoreWeightsView struct {
Priority float64 Priority float64
Load float64 Load float64
Queue float64 Queue float64
ErrorRate float64 ErrorRate float64
TTFT float64 TTFT float64
Quota float64
} }
func clamp01(value float64) float64 { func clamp01(value float64) float64 {
@ -1223,6 +1236,94 @@ func clamp01(value float64) float64 {
} }
} }
// selectByPowerOfTwo implements Power-of-Two-Choices (P2C): sample 2 random
// candidates and attempt the better-scored one first, then the other.
// This gives O(1) selection with load distribution comparable to top-K when N is large.
func (s *defaultOpenAIAccountScheduler) selectByPowerOfTwo(
ctx context.Context,
req OpenAIAccountScheduleRequest,
candidates []openAIAccountCandidateScore,
loadSkew float64,
) (*AccountSelectionResult, int, int, float64, error) {
n := len(candidates)
if n == 0 {
return nil, 0, 0, loadSkew, ErrNoAvailableAccounts
}
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
// Pick two distinct random indices.
idxA := int(rng.nextUint64() % uint64(n))
idxB := idxA
if n > 1 {
for idxB == idxA {
idxB = int(rng.nextUint64() % uint64(n))
}
}
// Order: better candidate first.
first, second := candidates[idxA], candidates[idxB]
if isOpenAIAccountCandidateBetter(second, first) {
first, second = second, first
}
tryAcquire := func(c openAIAccountCandidateScore) (*AccountSelectionResult, bool, error) {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
return nil, false, nil
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, req.RequireCompact)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
return nil, false, nil
}
result, err := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err != nil {
return nil, false, err
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true, nil
}
return nil, false, nil
}
for _, c := range []openAIAccountCandidateScore{first, second} {
result, ok, err := tryAcquire(c)
if err != nil {
return nil, n, 2, loadSkew, err
}
if ok {
return result, n, 2, loadSkew, nil
}
}
// Both slots busy — return wait plan on the better candidate.
cfg := s.service.schedulingConfig()
for _, c := range []openAIAccountCandidateScore{first, second} {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
continue
}
return &AccountSelectionResult{
Account: fresh,
WaitPlan: &AccountWaitPlan{
AccountID: fresh.ID,
MaxConcurrency: fresh.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, n, 2, loadSkew, nil
}
return nil, n, 2, loadSkew, ErrNoAvailableAccounts
}
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
if count <= 1 { if count <= 1 {
return 0 return 0

View File

@ -1448,3 +1448,129 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
func int64PtrForTest(v int64) *int64 { func int64PtrForTest(v int64) *int64 {
return &v return &v
} }
func TestAccount_GetQuotaRemainingFraction(t *testing.T) {
// No limit configured → always 1.0 (unlimited)
noLimit := &Account{}
require.Equal(t, 1.0, noLimit.GetQuotaRemainingFraction())
// 50% used
half := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 50.0}}
require.InDelta(t, 0.5, half.GetQuotaRemainingFraction(), 1e-9)
// Fully exhausted
full := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 100.0}}
require.Equal(t, 0.0, full.GetQuotaRemainingFraction())
// Over limit → clamp to 0
over := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 150.0}}
require.Equal(t, 0.0, over.GetQuotaRemainingFraction())
// Fresh (0 used)
fresh := &Account{Extra: map[string]any{"quota_limit": 200.0, "quota_used": 0.0}}
require.Equal(t, 1.0, fresh.GetQuotaRemainingFraction())
}
func TestOpenAIGatewayService_P2CEnabled(t *testing.T) {
require.False(t, (*OpenAIGatewayService)(nil).openAIWSP2CEnabled())
require.False(t, (&OpenAIGatewayService{}).openAIWSP2CEnabled())
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.EnableP2CScheduling = false
require.False(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
require.True(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
}
func TestOpenAIGatewayService_SchedulerWeights_QuotaField(t *testing.T) {
// Default weights: Quota is 0 (disabled by default)
svc := &OpenAIGatewayService{}
weights := svc.openAIWSSchedulerWeights()
require.Equal(t, 0.0, weights.Quota)
// Config-driven quota weight
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota = 0.4
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
require.Equal(t, 0.4, svcWithCfg.openAIWSSchedulerWeights().Quota)
}
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_SingleCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(99001)
account := &Account{ID: 71001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{account},
accountsByID: map[int64]*Account{71001: account},
}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
cfg.Gateway.OpenAIWS.LBTopK = 5
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
cfg: cfg,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.Equal(t, int64(71001), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_PicksBetterCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(99002)
// Account A has low priority (better), B has high priority (worse).
// With P2C enabled and a deterministic seed, we should always get a valid selection.
accountA := &Account{ID: 72001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
accountB := &Account{ID: 72002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 10}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{accountA, accountB},
accountsByID: map[int64]*Account{72001: accountA, 72002: accountB},
}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*accountA, *accountB}},
cfg: cfg,
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
// Either account is valid; just verify we got a schedulable one.
require.True(t, selection.Account.ID == 72001 || selection.Account.ID == 72002)
}
func TestDefaultOpenAIAccountScheduler_QuotaFactorInfluencesScore(t *testing.T) {
// Verify that quota weight affects scoring by checking GetQuotaRemainingFraction is used.
// Account with high remaining quota should score higher when quota weight > 0.
highQuota := &Account{
ID: 73001, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 10.0}, // 90% remaining
}
lowQuota := &Account{
ID: 73002, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 90.0}, // 10% remaining
}
require.InDelta(t, 0.9, highQuota.GetQuotaRemainingFraction(), 1e-9)
require.InDelta(t, 0.1, lowQuota.GetQuotaRemainingFraction(), 1e-9)
// With quota weight = 1.0 and all other weights = 0, high-quota account should win.
// We verify the score ordering directly using isOpenAIAccountCandidateBetter.
highScore := openAIAccountCandidateScore{account: highQuota, score: 0.9}
lowScore := openAIAccountCandidateScore{account: lowQuota, score: 0.1}
require.True(t, isOpenAIAccountCandidateBetter(highScore, lowScore))
require.False(t, isOpenAIAccountCandidateBetter(lowScore, highScore))
}

View File

@ -0,0 +1,75 @@
package service
import (
"sync"
"sync/atomic"
"time"
)
const requestEventBufSize = 64
// RequestEvent is published for every gateway dispatch completion.
type RequestEvent struct {
Timestamp time.Time `json:"timestamp"`
Method string `json:"method"`
Path string `json:"path"`
Model string `json:"model"`
AccountID int64 `json:"account_id"`
// Status is "success", "error", or "rate_limited".
Status string `json:"status"`
LatencyMS int64 `json:"latency_ms"`
}
// RequestEventBus is a fan-out hub for real-time request events.
// Publishers call Publish; subscribers call Subscribe/Unsubscribe.
// Each subscriber gets its own buffered channel. If the buffer is full
// the event is dropped for that subscriber (non-blocking publish).
type RequestEventBus struct {
mu sync.RWMutex
subscribers map[uint64]chan RequestEvent
nextID atomic.Uint64
}
func NewRequestEventBus() *RequestEventBus {
return &RequestEventBus{
subscribers: make(map[uint64]chan RequestEvent),
}
}
// Subscribe registers a new subscriber and returns its ID and a receive-only channel.
func (b *RequestEventBus) Subscribe() (uint64, <-chan RequestEvent) {
id := b.nextID.Add(1)
ch := make(chan RequestEvent, requestEventBufSize)
b.mu.Lock()
b.subscribers[id] = ch
b.mu.Unlock()
return id, ch
}
// Unsubscribe removes a subscriber and closes its channel.
func (b *RequestEventBus) Unsubscribe(id uint64) {
b.mu.Lock()
ch, ok := b.subscribers[id]
if ok {
delete(b.subscribers, id)
}
b.mu.Unlock()
if ok {
close(ch)
}
}
// Publish sends an event to all current subscribers without blocking.
func (b *RequestEventBus) Publish(e RequestEvent) {
if b == nil {
return
}
b.mu.RLock()
defer b.mu.RUnlock()
for _, ch := range b.subscribers {
select {
case ch <- e:
default:
}
}
}

View File

@ -0,0 +1,100 @@
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRequestEventBus_PublishToSubscriber(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
defer bus.Unsubscribe(id)
evt := RequestEvent{Model: "claude-3", Status: "success", LatencyMS: 100}
bus.Publish(evt)
select {
case got := <-ch:
assert.Equal(t, evt, got)
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
}
func TestRequestEventBus_MultipleSubscribers(t *testing.T) {
bus := NewRequestEventBus()
id1, ch1 := bus.Subscribe()
id2, ch2 := bus.Subscribe()
defer bus.Unsubscribe(id1)
defer bus.Unsubscribe(id2)
evt := RequestEvent{Model: "claude-3", Status: "error"}
bus.Publish(evt)
for _, ch := range []<-chan RequestEvent{ch1, ch2} {
select {
case got := <-ch:
assert.Equal(t, evt, got)
case <-time.After(time.Second):
t.Fatal("timed out waiting for event on one subscriber")
}
}
}
func TestRequestEventBus_UnsubscribeClosesChannel(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
bus.Unsubscribe(id)
// Channel should be closed.
_, ok := <-ch
assert.False(t, ok, "channel should be closed after Unsubscribe")
}
func TestRequestEventBus_UnsubscribedMissesEvents(t *testing.T) {
bus := NewRequestEventBus()
id, _ := bus.Subscribe()
bus.Unsubscribe(id)
// Publish after unsubscribe should not panic.
require.NotPanics(t, func() {
bus.Publish(RequestEvent{Model: "test"})
})
}
func TestRequestEventBus_DropWhenFull(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
defer bus.Unsubscribe(id)
// Fill the buffer then publish one more — should drop, not block.
evt := RequestEvent{Model: "model", Status: "success"}
for i := 0; i < requestEventBufSize; i++ {
bus.Publish(evt)
}
// This publish should return immediately (dropped).
done := make(chan struct{})
go func() {
bus.Publish(evt)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("Publish blocked when buffer was full")
}
assert.Len(t, ch, requestEventBufSize)
}
func TestRequestEventBus_NilSafePublish(t *testing.T) {
var bus *RequestEventBus
require.NotPanics(t, func() {
bus.Publish(RequestEvent{Model: "test"})
})
}

View File

@ -0,0 +1,120 @@
package service
import (
"context"
"errors"
"math"
"sync"
"time"
)
// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait.
var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
// RPMTokenBucketService provides per-account token buckets for RPM smoothing.
// When an account's RPM budget is exhausted, callers can wait up to a configured
// deadline instead of receiving an immediate 429. The bucket refills continuously
// at rpm/60 tokens per second so requests are distributed evenly over time.
type RPMTokenBucketService struct {
buckets sync.Map // map[int64]*rpmEntry
}
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
func NewRPMTokenBucketService() *RPMTokenBucketService {
return &RPMTokenBucketService{}
}
type rpmEntry struct {
bucket *tokenBucket
rpm int
}
// getBucket returns (or creates) the token bucket for accountID.
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
if v, ok := s.buckets.Load(accountID); ok {
e := v.(*rpmEntry)
if e.rpm == rpm {
return e.bucket
}
// RPM limit changed — replace with a fresh bucket.
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
s.buckets.Store(accountID, fresh)
return fresh.bucket
}
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
actual, _ := s.buckets.LoadOrStore(accountID, entry)
return actual.(*rpmEntry).bucket
}
// AcquireWithWait attempts to consume one token for the given account.
// It blocks up to maxWait for a token to become available.
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
// or ctx.Err() if the context is cancelled.
// If rpm <= 0 the call returns immediately with nil.
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
if rpm <= 0 {
return nil
}
bucket := s.getBucket(accountID, rpm)
deadline := time.Now().Add(maxWait)
for {
ok, waitDur := bucket.tryAcquire()
if ok {
return nil
}
remaining := time.Until(deadline)
if remaining <= 0 || waitDur > remaining {
return ErrRPMWaitTimeout
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(waitDur):
// token may be available now; retry
}
}
}
// tokenBucket is a continuous-refill token bucket for a single account.
type tokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
rateSec float64 // tokens refilled per second = rpm / 60
lastFill time.Time
}
func newTokenBucket(rpm int) *tokenBucket {
max := float64(rpm)
return &tokenBucket{
tokens: max,
maxTokens: max,
rateSec: float64(rpm) / 60.0,
lastFill: time.Now(),
}
}
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
b.mu.Lock()
defer b.mu.Unlock()
now := time.Now()
elapsed := now.Sub(b.lastFill).Seconds()
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
b.lastFill = now
if b.tokens >= 1.0 {
b.tokens -= 1.0
return true, 0
}
deficit := 1.0 - b.tokens
waitSecs := deficit / b.rateSec
return false, time.Duration(waitSecs * float64(time.Second))
}

View File

@ -0,0 +1,108 @@
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately.
for i := 0; i < 60; i++ {
err := svc.AcquireWithWait(ctx, 1, 60, 0)
require.NoError(t, err, "call %d should succeed immediately", i+1)
}
}
func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) {
svc := NewRPMTokenBucketService()
err := svc.AcquireWithWait(context.Background(), 42, 0, 0)
assert.NoError(t, err)
}
func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// rpm=1 → 1 token/minute. One call drains the bucket.
err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second)
require.NoError(t, err, "first call should succeed")
// Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms.
start := time.Now()
err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond)
elapsed := time.Since(start)
assert.ErrorIs(t, err, ErrRPMWaitTimeout)
assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block")
}
func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully.
for i := 0; i < 120; i++ {
require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0))
}
// Next call needs to wait ~500ms for the next token. Give it 2s.
start := time.Now()
err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second)
elapsed := time.Since(start)
require.NoError(t, err, "should succeed after waiting for refill")
assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited")
assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long")
}
func TestRPMTokenBucket_ContextCancellation(t *testing.T) {
svc := NewRPMTokenBucketService()
// rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining.
// maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms).
// Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first.
for i := 0; i < 120; i++ {
require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0))
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
start := time.Now()
err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second)
elapsed := time.Since(start)
assert.ErrorIs(t, err, context.Canceled)
assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly")
}
func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Drain account 1 (rpm=1).
require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0))
// Account 2 has its own bucket and should succeed immediately.
err := svc.AcquireWithWait(ctx, 2, 1, 0)
assert.NoError(t, err, "different account should have an independent bucket")
}
func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Create bucket with rpm=1 and drain it.
require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0))
// Bucket now empty with rpm=1.
// Changing RPM to 60 should reset the bucket to full (60 tokens).
err := svc.AcquireWithWait(ctx, 10, 60, 0)
assert.NoError(t, err, "new RPM should cause bucket recreation")
}

View File

@ -426,6 +426,8 @@ var ProviderSet = wire.NewSet(
ProvideBillingCacheService, ProvideBillingCacheService,
NewAnnouncementService, NewAnnouncementService,
NewAdminService, NewAdminService,
NewRPMTokenBucketService,
NewRequestEventBus,
NewGatewayService, NewGatewayService,
NewOpenAIGatewayService, NewOpenAIGatewayService,
NewOAuthService, NewOAuthService,
@ -448,6 +450,7 @@ var ProviderSet = wire.NewSet(
ProvideSettingService, ProvideSettingService,
NewDataManagementService, NewDataManagementService,
ProvideBackupService, ProvideBackupService,
NewHealthService,
ProvideOpsSystemLogSink, ProvideOpsSystemLogSink,
NewOpsService, NewOpsService,
ProvideOpsMetricsCollector, ProvideOpsMetricsCollector,

View File

@ -120,7 +120,7 @@ services:
networks: networks:
- sub2api-network - sub2api-network
healthcheck: healthcheck:
test: [ "CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health" ] test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/ready"]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3