feat: merge feat/omniroute-ideas — P2C scheduler, quota scoring, tier fallback
This commit is contained in:
commit
fdd2d08a4d
@ -188,7 +188,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
@ -203,7 +204,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
requestEventBus := service.NewRequestEventBus()
|
||||||
|
opsHandler := admin.NewOpsHandler(opsService, requestEventBus)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||||
@ -244,7 +246,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
@ -257,7 +259,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
|
healthService := service.NewHealthService(db, redisClient)
|
||||||
|
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient)
|
||||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||||
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||||
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||||
|
|||||||
@ -691,6 +691,14 @@ type GatewayConfig struct {
|
|||||||
// UserMessageQueue: 用户消息串行队列配置
|
// UserMessageQueue: 用户消息串行队列配置
|
||||||
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
|
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
|
||||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||||
|
|
||||||
|
// RPMSmoothing: RPM 令牌桶平滑配置
|
||||||
|
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
||||||
|
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
||||||
|
|
||||||
|
// ContextCompression: 主动上下文压缩配置
|
||||||
|
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
|
||||||
|
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayAntigravityLSWorkerConfig struct {
|
type GatewayAntigravityLSWorkerConfig struct {
|
||||||
@ -745,6 +753,37 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RPMSmoothingConfig configures token-bucket smoothing of RPM limits.
type RPMSmoothingConfig struct {
	// Enabled turns RPM token-bucket smoothing on (default false).
	// When enabled, a request that finds the account's RPM quota exhausted
	// waits up to MaxWaitMS milliseconds instead of immediately returning 429.
	Enabled bool `mapstructure:"enabled"`
	// MaxWaitMS is the longest time, in milliseconds, to wait for a token;
	// once it elapses a 429 is returned (default 5000).
	MaxWaitMS int `mapstructure:"max_wait_ms"`
}

// MaxWait reports how long a request may wait for a token. A MaxWaitMS that
// is zero or negative selects the 5-second default.
func (c *RPMSmoothingConfig) MaxWait() time.Duration {
	if ms := c.MaxWaitMS; ms > 0 {
		return time.Duration(ms) * time.Millisecond
	}
	return 5 * time.Second
}
|
||||||
|
|
||||||
|
// ContextCompressionConfig configures proactive context compression.
type ContextCompressionConfig struct {
	// MaxTokens is the compression target in tokens (approximated as chars/4).
	// When the budget is exceeded, messages are trimmed starting from the
	// oldest (default 190000).
	MaxTokens int `mapstructure:"max_tokens"`
}

// GetMaxTokens returns the configured token budget. A MaxTokens that is zero
// or negative selects the default of 190 000.
func (c *ContextCompressionConfig) GetMaxTokens() int {
	if budget := c.MaxTokens; budget > 0 {
		return budget
	}
	return 190_000
}
|
||||||
|
|
||||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||||
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||||
type GatewayOpenAIWSConfig struct {
|
type GatewayOpenAIWSConfig struct {
|
||||||
@ -828,6 +867,8 @@ type GatewayOpenAIWSConfig struct {
|
|||||||
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
|
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
|
||||||
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
|
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
|
||||||
|
|
||||||
|
// EnableP2CScheduling: 启用 Power-of-Two-Choices 调度(默认 false,使用 top-K 加权随机)
|
||||||
|
EnableP2CScheduling bool `mapstructure:"enable_p2c_scheduling"`
|
||||||
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
|
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -838,6 +879,8 @@ type GatewayOpenAIWSSchedulerScoreWeights struct {
|
|||||||
Queue float64 `mapstructure:"queue"`
|
Queue float64 `mapstructure:"queue"`
|
||||||
ErrorRate float64 `mapstructure:"error_rate"`
|
ErrorRate float64 `mapstructure:"error_rate"`
|
||||||
TTFT float64 `mapstructure:"ttft"`
|
TTFT float64 `mapstructure:"ttft"`
|
||||||
|
// Quota: 剩余配额比例权重(0 表示不参与打分)
|
||||||
|
Quota float64 `mapstructure:"quota"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
||||||
@ -992,6 +1035,10 @@ type GatewaySchedulingConfig struct {
|
|||||||
// 全量重建周期配置
|
// 全量重建周期配置
|
||||||
// 全量重建周期(秒),0 表示禁用
|
// 全量重建周期(秒),0 表示禁用
|
||||||
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
|
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
|
||||||
|
|
||||||
|
// EnableTierFallbackChain: 启用跨档降级链(订阅 → API Key → Bedrock),默认 false
|
||||||
|
// 仅对 Anthropic 平台生效;启用后账号按类型分层,优先使用订阅账号,依次降级。
|
||||||
|
EnableTierFallbackChain bool `mapstructure:"enable_tier_fallback_chain"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerConfig) Address() string {
|
func (s *ServerConfig) Address() string {
|
||||||
@ -2504,7 +2551,8 @@ func (c *Config) Validate() error {
|
|||||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
|
||||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
|
||||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
|
||||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 ||
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Quota < 0 {
|
||||||
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
|
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
|
||||||
}
|
}
|
||||||
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
|
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
|
||||||
|
|||||||
@ -16,7 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type OpsHandler struct {
|
type OpsHandler struct {
|
||||||
opsService *service.OpsService
|
opsService *service.OpsService
|
||||||
|
requestEventBus *service.RequestEventBus
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetErrorLogByID returns ops error log detail.
|
// GetErrorLogByID returns ops error log detail.
|
||||||
@ -70,8 +71,8 @@ func parseOpsViewParam(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler {
|
||||||
return &OpsHandler{opsService: opsService}
|
return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetErrorLogs lists ops error logs.
|
// GetErrorLogs lists ops error logs.
|
||||||
|
|||||||
@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||||
r := newOpsRuntimeRouter(h, false)
|
r := newOpsRuntimeRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||||
r := newOpsRuntimeRouter(h, false)
|
r := newOpsRuntimeRouter(h, false)
|
||||||
|
|
||||||
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||||
@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
||||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||||
r := newOpsRuntimeRouter(h, true)
|
r := newOpsRuntimeRouter(h, true)
|
||||||
|
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
|
|||||||
@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||||
h := NewOpsHandler(nil)
|
h := NewOpsHandler(nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
|||||||
|
|
||||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, true)
|
r := newOpsSystemLogTestRouter(h, true)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
|||||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||||
sink := service.NewOpsSystemLogSink(nil)
|
sink := service.NewOpsSystemLogSink(nil)
|
||||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||||
h := NewOpsHandler(svc)
|
h := NewOpsHandler(svc, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||||
h := NewOpsHandler(nil)
|
h := NewOpsHandler(nil, nil)
|
||||||
r := newOpsSystemLogTestRouter(h, false)
|
r := newOpsSystemLogTestRouter(h, false)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T
|
|||||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||||
Ops: config.OpsConfig{Enabled: false},
|
Ops: config.OpsConfig{Enabled: false},
|
||||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
h = NewOpsHandler(svc)
|
h = NewOpsHandler(svc, nil)
|
||||||
r = newOpsSystemLogTestRouter(h, false)
|
r = newOpsSystemLogTestRouter(h, false)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||||
|
|||||||
198
backend/internal/handler/admin/ops_ws_requests_handler.go
Normal file
198
backend/internal/handler/admin/ops_ws_requests_handler.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestStreamWSMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data service.RequestEvent `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestStreamWSHandler streams real-time request events to WebSocket clients.
|
||||||
|
// GET /api/v1/admin/ops/ws/requests
|
||||||
|
//
|
||||||
|
// Each connected client receives a JSON message per gateway dispatch:
|
||||||
|
//
|
||||||
|
// {"type":"request_event","data":{"timestamp":...,"method":"POST","path":"/v1/messages",
|
||||||
|
// "model":"claude-3-5-sonnet-20241022","account_id":42,"status":"success","latency_ms":1230}}
|
||||||
|
func (h *OpsHandler) RequestStreamWSHandler(c *gin.Context) {
|
||||||
|
clientIP := requestClientIP(c.Request)
|
||||||
|
|
||||||
|
if h == nil || h.opsService == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.requestEventBus == nil {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "request event bus not initialized"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
|
||||||
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if wsConnCount.Add(-1) == 0 {
|
||||||
|
scheduleQPSWSIdleStop()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
|
||||||
|
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] per-ip limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer releaseOpsWSIPSlot(clientIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] upgrade failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
handleRequestStreamWebSocket(c.Request.Context(), conn, h.requestEventBus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleRequestStreamWebSocket(parentCtx context.Context, conn *websocket.Conn, bus *service.RequestEventBus) {
|
||||||
|
if conn == nil || bus == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
subID, eventCh := bus.Subscribe()
|
||||||
|
defer bus.Unsubscribe(subID)
|
||||||
|
|
||||||
|
var closeOnce sync.Once
|
||||||
|
closeConn := func() {
|
||||||
|
closeOnce.Do(func() { _ = conn.Close() })
|
||||||
|
}
|
||||||
|
|
||||||
|
closeFrameCh := make(chan []byte, 1)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn.SetReadLimit(qpsWSMaxReadBytes)
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] set read deadline failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.SetPongHandler(func(string) error {
|
||||||
|
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
|
||||||
|
})
|
||||||
|
conn.SetCloseHandler(func(code int, text string) error {
|
||||||
|
select {
|
||||||
|
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, _, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] read failed: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
pingTicker := time.NewTicker(qpsWSPingInterval)
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
|
||||||
|
writeWithTimeout := func(messageType int, data []byte) error {
|
||||||
|
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return conn.WriteMessage(messageType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendClose := func(closeFrame []byte) {
|
||||||
|
if closeFrame == nil {
|
||||||
|
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
|
||||||
|
}
|
||||||
|
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case evt, ok := <-eventCh:
|
||||||
|
if !ok {
|
||||||
|
// channel closed by Unsubscribe
|
||||||
|
sendClose(nil)
|
||||||
|
closeConn()
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg, err := json.Marshal(requestStreamWSMessage{Type: "request_event", Data: evt})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] write failed: %v", err)
|
||||||
|
cancel()
|
||||||
|
closeConn()
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-pingTicker.C:
|
||||||
|
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] ping failed: %v", err)
|
||||||
|
cancel()
|
||||||
|
closeConn()
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case closeFrame := <-closeFrameCh:
|
||||||
|
sendClose(closeFrame)
|
||||||
|
closeConn()
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
var closeFrame []byte
|
||||||
|
select {
|
||||||
|
case closeFrame = <-closeFrameCh:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
sendClose(closeFrame)
|
||||||
|
closeConn()
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -48,6 +48,7 @@ type GatewayHandler struct {
|
|||||||
errorPassthroughService *service.ErrorPassthroughService
|
errorPassthroughService *service.ErrorPassthroughService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
userMsgQueueHelper *UserMsgQueueHelper
|
userMsgQueueHelper *UserMsgQueueHelper
|
||||||
|
requestEventBus *service.RequestEventBus
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
maxAccountSwitchesGemini int
|
maxAccountSwitchesGemini int
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@ -70,6 +71,7 @@ func NewGatewayHandler(
|
|||||||
userMsgQueueService *service.UserMessageQueueService,
|
userMsgQueueService *service.UserMessageQueueService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
|
requestEventBus *service.RequestEventBus,
|
||||||
) *GatewayHandler {
|
) *GatewayHandler {
|
||||||
pingInterval := time.Duration(0)
|
pingInterval := time.Duration(0)
|
||||||
maxAccountSwitches := 10
|
maxAccountSwitches := 10
|
||||||
@ -103,6 +105,7 @@ func NewGatewayHandler(
|
|||||||
errorPassthroughService: errorPassthroughService,
|
errorPassthroughService: errorPassthroughService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
userMsgQueueHelper: umqHelper,
|
userMsgQueueHelper: umqHelper,
|
||||||
|
requestEventBus: requestEventBus,
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@ -113,6 +116,7 @@ func NewGatewayHandler(
|
|||||||
// Messages handles Claude API compatible messages endpoint
|
// Messages handles Claude API compatible messages endpoint
|
||||||
// POST /v1/messages
|
// POST /v1/messages
|
||||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||||
|
reqStartTime := time.Now()
|
||||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -164,6 +168,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 解析渠道级模型映射
|
// 解析渠道级模型映射
|
||||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
|
||||||
|
// 实时请求查看器:记录每次请求的结果(账号、模型、状态、延迟)
|
||||||
|
var (
|
||||||
|
reqEventAccountID int64
|
||||||
|
reqEventStatus = "error"
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if h.requestEventBus != nil {
|
||||||
|
h.requestEventBus.Publish(service.RequestEvent{
|
||||||
|
Timestamp: reqStartTime,
|
||||||
|
Method: c.Request.Method,
|
||||||
|
Path: c.FullPath(),
|
||||||
|
Model: reqModel,
|
||||||
|
AccountID: reqEventAccountID,
|
||||||
|
Status: reqEventStatus,
|
||||||
|
LatencyMS: time.Since(reqStartTime).Milliseconds(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||||
@ -406,6 +429,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||||
|
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||||
|
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||||
|
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||||
|
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||||
|
rpmCancel()
|
||||||
|
if rpmErr != nil {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
reqEventAccountID = account.ID
|
||||||
|
reqEventStatus = "rate_limited"
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
@ -473,6 +513,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 实时请求查看器:标记 Gemini 路径成功
|
||||||
|
reqEventAccountID = account.ID
|
||||||
|
reqEventStatus = "success"
|
||||||
|
|
||||||
// RPM 计数递增(Forward 成功后)
|
// RPM 计数递增(Forward 成功后)
|
||||||
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
||||||
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
||||||
@ -650,6 +694,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||||
|
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||||
|
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||||
|
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||||
|
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||||
|
rpmCancel()
|
||||||
|
if rpmErr != nil {
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
reqEventAccountID = account.ID
|
||||||
|
reqEventStatus = "rate_limited"
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
@ -850,6 +911,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 实时请求查看器:标记 Anthropic 路径成功
|
||||||
|
reqEventAccountID = account.ID
|
||||||
|
reqEventStatus = "success"
|
||||||
|
|
||||||
// RPM 计数递增(Forward 成功后)
|
// RPM 计数递增(Forward 成功后)
|
||||||
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
||||||
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
entsql "entgo.io/ent/dialect/sql"
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
"entgo.io/ent/dialect/sql/sqljson"
|
"entgo.io/ent/dialect/sql/sqljson"
|
||||||
@ -49,6 +50,13 @@ type accountRepository struct {
|
|||||||
// Used to proactively sync account snapshot to cache when status changes,
|
// Used to proactively sync account snapshot to cache when status changes,
|
||||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||||
schedulerCache service.SchedulerCache
|
schedulerCache service.SchedulerCache
|
||||||
|
|
||||||
|
// tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。
|
||||||
|
// 上游 401/限流爆发时,N 个 in-flight 请求会同时调用此方法;底层 SQL
|
||||||
|
// 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row,但 N 次
|
||||||
|
// SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight
|
||||||
|
// 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。
|
||||||
|
tempUnschedSF singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||||
@ -1124,6 +1132,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
// 进程内合并并发调用:key 包含 until 和 reason,确保不同窗口/原因独立去重。
|
||||||
|
// until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致;
|
||||||
|
// 哪怕略有偏差,SQL 的 (existing < new) 条件保证语义安全。
|
||||||
|
sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason
|
||||||
|
_, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) {
|
||||||
|
return nil, r.setTempUnschedulableOnce(ctx, id, until, reason)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
_, err := r.sql.ExecContext(ctx, `
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
UPDATE accounts
|
UPDATE accounts
|
||||||
SET temp_unschedulable_until = $1,
|
SET temp_unschedulable_until = $1,
|
||||||
|
|||||||
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// blockingExecutor is a minimal SQL executor stand-in used to control the
// timing of concurrent calls precisely. ExecContext does not return until
// Release has been called, which lets many goroutines pile up behind a
// single in-flight call.
type blockingExecutor struct {
	mu          sync.Mutex
	execCalls   int32         // total ExecContext invocations
	queryCalls  int32         // total QueryContext invocations
	release     chan struct{} // closed by Release to unblock ExecContext
	concurrent  int32         // ExecContext calls currently in flight
	maxObserved int32         // highest value concurrent has reached
}

// newBlockingExecutor returns an executor whose ExecContext calls block
// until Release is invoked.
func newBlockingExecutor() *blockingExecutor {
	e := &blockingExecutor{}
	e.release = make(chan struct{})
	return e
}

// Release unblocks every pending and future ExecContext call by closing the
// release channel. It must be called at most once.
func (e *blockingExecutor) Release() {
	close(e.release)
}

// ExecContext records the call, updates the in-flight and high-water
// counters, waits for Release, and then reports one affected row. The
// context, query and arguments are ignored.
func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
	atomic.AddInt32(&e.execCalls, 1)

	inFlight := atomic.AddInt32(&e.concurrent, 1)
	defer atomic.AddInt32(&e.concurrent, -1)

	// Raise the high-water mark with a CAS loop so concurrent updates are
	// never lost.
	for {
		seen := atomic.LoadInt32(&e.maxObserved)
		if inFlight <= seen {
			break
		}
		if atomic.CompareAndSwapInt32(&e.maxObserved, seen, inFlight) {
			break
		}
	}

	<-e.release
	return driverResult{}, nil
}

// QueryContext counts the call and always reports sql.ErrNoRows with nil rows.
func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
	atomic.AddInt32(&e.queryCalls, 1)
	return nil, sql.ErrNoRows
}

// driverResult is a trivial sql.Result implementation for tests.
type driverResult struct{}

// LastInsertId always reports 0.
func (driverResult) LastInsertId() (int64, error) {
	return 0, nil
}

// RowsAffected always reports a single affected row.
func (driverResult) RowsAffected() (int64, error) {
	return 1, nil
}
|
||||||
|
|
||||||
|
func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||||
|
// 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际
|
||||||
|
// SQL 路径(UPDATE + outbox INSERT = 2 次 ExecContext)。
|
||||||
|
exec := newBlockingExecutor()
|
||||||
|
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||||
|
|
||||||
|
const callers = 30
|
||||||
|
until := time.Now().Add(10 * time.Minute)
|
||||||
|
const reason = "OAuth 401: invalid_grant"
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(callers)
|
||||||
|
for i := 0; i < callers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = repo.SetTempUnschedulable(context.Background(), 42, until, reason)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent),
|
||||||
|
"singleflight should serialize the SQL call to exactly one in-flight execution")
|
||||||
|
|
||||||
|
exec.Release()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec;其余 29 个 caller 共享结果。
|
||||||
|
require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2),
|
||||||
|
"expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved),
|
||||||
|
"no two SQL execs should run concurrently for the same singleflight key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) {
|
||||||
|
// 不同 account 应分属不同 sf key,能并行写库。
|
||||||
|
exec := newBlockingExecutor()
|
||||||
|
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||||
|
|
||||||
|
until := time.Now().Add(10 * time.Minute)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := int64(1); i <= 3; i++ {
|
||||||
|
i := i
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved),
|
||||||
|
"different accounts should be able to write in parallel")
|
||||||
|
|
||||||
|
exec.Release()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@ -38,6 +38,7 @@ func ProvideRouter(
|
|||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
opsService *service.OpsService,
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
|
healthService *service.HealthService,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
if cfg.Server.Mode == "release" {
|
if cfg.Server.Mode == "release" {
|
||||||
@ -95,7 +96,7 @@ func ProvideRouter(
|
|||||||
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
|
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
|
||||||
})
|
})
|
||||||
|
|
||||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideHTTPServer 提供 HTTP 服务器
|
// ProvideHTTPServer 提供 HTTP 服务器
|
||||||
|
|||||||
@ -30,6 +30,7 @@ func SetupRouter(
|
|||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
opsService *service.OpsService,
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
|
healthService *service.HealthService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
@ -81,7 +82,7 @@ func SetupRouter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -97,11 +98,12 @@ func registerRoutes(
|
|||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
opsService *service.OpsService,
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
|
healthService *service.HealthService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) {
|
) {
|
||||||
// 通用路由(健康检查、状态等)
|
// 通用路由(健康检查、状态等)
|
||||||
routes.RegisterCommonRoutes(r)
|
routes.RegisterCommonRoutes(r, healthService)
|
||||||
|
|
||||||
// API v1
|
// API v1
|
||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
|
|||||||
@ -151,10 +151,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
|
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WebSocket realtime (QPS/TPS)
|
// WebSocket realtime (QPS/TPS and request stream)
|
||||||
ws := ops.Group("/ws")
|
ws := ops.Group("/ws")
|
||||||
{
|
{
|
||||||
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
||||||
|
ws.GET("/requests", h.Admin.Ops.RequestStreamWSHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error logs (legacy)
|
// Error logs (legacy)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,11 +17,37 @@ const (
|
|||||||
claudeCodeGrowthBookDateUpdated = "1970-01-01T00:00:00Z"
|
claudeCodeGrowthBookDateUpdated = "1970-01-01T00:00:00Z"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
// readinessHandlerTimeout 限定 readiness 端点对外的最大返回耗时。
|
||||||
func RegisterCommonRoutes(r *gin.Engine) {
|
// HealthService 内部对每个组件再有独立超时,所以这里给宽一点即可。
|
||||||
// 健康检查
|
const readinessHandlerTimeout = 3 * time.Second
|
||||||
r.GET("/health", func(c *gin.Context) {
|
|
||||||
|
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)。
|
||||||
|
//
|
||||||
|
// 健康端点的语义分层:
|
||||||
|
// - /healthz : liveness 探针。零依赖、永远 200。容器/进程探活专用。
|
||||||
|
// - /ready : readiness 探针。检查 DB+Redis;任一失败返回 503。
|
||||||
|
// - /health : 历史端点,等价于 /healthz,保留向后兼容。
|
||||||
|
//
|
||||||
|
// dashboard 用的"业务健康分"由 ops_health_score 单独提供,与本路由无关。
|
||||||
|
func RegisterCommonRoutes(r *gin.Engine, healthService *service.HealthService) {
|
||||||
|
// Liveness:仅证明进程在响应。
|
||||||
|
livenessHandler := func(c *gin.Context) {
|
||||||
|
_ = healthService.Liveness()
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
}
|
||||||
|
r.GET("/healthz", livenessHandler)
|
||||||
|
r.GET("/health", livenessHandler) // 向后兼容旧的 docker-compose healthcheck
|
||||||
|
|
||||||
|
// Readiness:检查关键依赖。失败时返回 503 但仍带详情,便于排障。
|
||||||
|
r.GET("/ready", func(c *gin.Context) {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), readinessHandlerTimeout)
|
||||||
|
defer cancel()
|
||||||
|
report := healthService.Readiness(ctx)
|
||||||
|
status := http.StatusOK
|
||||||
|
if !report.OK {
|
||||||
|
status = http.StatusServiceUnavailable
|
||||||
|
}
|
||||||
|
c.JSON(status, report)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
|
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
|
||||||
|
|||||||
@ -7,16 +7,56 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newCommonRoutesTestRouter() *gin.Engine {
|
func newCommonRoutesTestRouter() *gin.Engine {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
RegisterCommonRoutes(r)
|
RegisterCommonRoutes(r, service.NewHealthService(nil, nil))
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestRouter(t *testing.T, hs *service.HealthService) *gin.Engine {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
r := gin.New()
|
||||||
|
RegisterCommonRoutes(r, hs)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonRoutes_LivenessEndpoints(t *testing.T) {
|
||||||
|
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||||
|
for _, path := range []string{"/healthz", "/health"} {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "liveness path %s should be 200", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonRoutes_ReadyEndpoint_NoDepsReturnsOK(t *testing.T) {
|
||||||
|
// 没有 DB/Redis 依赖时 readiness 视为 ok(早期启动场景)。
|
||||||
|
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ready", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Contains(t, w.Body.String(), "\"ok\":true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonRoutes_SetupStatusUnchanged(t *testing.T) {
|
||||||
|
// 验证我们没有破坏既有的 /setup/status 行为(前端依赖)。
|
||||||
|
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/setup/status", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Contains(t, w.Body.String(), "needs_setup")
|
||||||
|
}
|
||||||
|
|
||||||
func TestCommonRoutes_ClaudeCodeBootstrap(t *testing.T) {
|
func TestCommonRoutes_ClaudeCodeBootstrap(t *testing.T) {
|
||||||
r := newCommonRoutesTestRouter()
|
r := newCommonRoutesTestRouter()
|
||||||
|
|
||||||
|
|||||||
@ -988,6 +988,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsContextCompressionEnabled returns true if the account has opted into proactive
|
||||||
|
// context compression. When enabled, the gateway will trim oldest messages before
|
||||||
|
// dispatch to keep the estimated token count within the configured budget.
|
||||||
|
func (a *Account) IsContextCompressionEnabled() bool {
|
||||||
|
if a.Credentials == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
enabled, _ := a.Credentials["enable_context_compression"].(bool)
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsBedrock() bool {
|
func (a *Account) IsBedrock() bool {
|
||||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||||
}
|
}
|
||||||
@ -1572,6 +1583,24 @@ func (a *Account) GetQuotaUsed() float64 {
|
|||||||
return a.getExtraFloat64("quota_used")
|
return a.getExtraFloat64("quota_used")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetQuotaRemainingFraction returns the fraction of total quota remaining in [0,1].
|
||||||
|
// Returns 1.0 when no quota limit is set (limit == 0 means unlimited).
|
||||||
|
func (a *Account) GetQuotaRemainingFraction() float64 {
|
||||||
|
limit := a.GetQuotaLimit()
|
||||||
|
if limit <= 0 {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
used := a.GetQuotaUsed()
|
||||||
|
remaining := (limit - used) / limit
|
||||||
|
if remaining < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if remaining > 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return remaining
|
||||||
|
}
|
||||||
|
|
||||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||||
return a.getExtraFloat64("quota_daily_limit")
|
return a.getExtraFloat64("quota_daily_limit")
|
||||||
|
|||||||
@ -40,6 +40,7 @@ func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
|
|||||||
want bool
|
want bool
|
||||||
}{
|
}{
|
||||||
{"标准版本号", "claude-cli/1.0.0", true},
|
{"标准版本号", "claude-cli/1.0.0", true},
|
||||||
|
{"官方 transport UA", "claude-code/2.1.88", true},
|
||||||
{"多位版本号", "claude-cli/12.34.56", true},
|
{"多位版本号", "claude-cli/12.34.56", true},
|
||||||
{"大写开头", "Claude-CLI/1.0.0", true},
|
{"大写开头", "Claude-CLI/1.0.0", true},
|
||||||
{"非 claude-cli", "curl/7.64.1", false},
|
{"非 claude-cli", "curl/7.64.1", false},
|
||||||
@ -90,6 +91,19 @@ func TestValidate_MessagesPath_FullValid(t *testing.T) {
|
|||||||
require.True(t, result, "完整有效请求应通过")
|
require.True(t, result, "完整有效请求应通过")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidate_MessagesPath_FullValid_ClaudeCodeUA(t *testing.T) {
|
||||||
|
v := newTestValidator()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||||
|
req.Header.Set("User-Agent", "claude-code/2.1.88")
|
||||||
|
req.Header.Set("X-App", "claude-code")
|
||||||
|
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
|
||||||
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
|
||||||
|
result := v.Validate(req, validClaudeCodeBody())
|
||||||
|
require.True(t, result, "官方 transport/helper UA 也应通过")
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
|
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
|
||||||
v := newTestValidator()
|
v := newTestValidator()
|
||||||
body := validClaudeCodeBody()
|
body := validClaudeCodeBody()
|
||||||
|
|||||||
@ -15,11 +15,13 @@ import (
|
|||||||
type ClaudeCodeValidator struct{}
|
type ClaudeCodeValidator struct{}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
// User-Agent 匹配: 官方 Claude Code 目前存在两类产品前缀:
|
||||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
// 1. 主 Anthropic API 客户端: claude-cli/x.y.z (...)
|
||||||
|
// 2. transport / helper 请求: claude-code/x.y.z
|
||||||
|
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/\d+\.\d+\.\d+`)
|
||||||
|
|
||||||
// 带捕获组的版本提取正则
|
// 带捕获组的版本提取正则
|
||||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/(\d+\.\d+\.\d+)`)
|
||||||
|
|
||||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||||
systemPromptThreshold = 0.5
|
systemPromptThreshold = 0.5
|
||||||
@ -55,7 +57,7 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
|||||||
// Validate 验证请求是否来自 Claude Code CLI
|
// Validate 验证请求是否来自 Claude Code CLI
|
||||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||||
//
|
//
|
||||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
// Step 1: User-Agent 检查 (必需) - 必须是官方 claude-cli/ 或 claude-code/ 前缀
|
||||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
||||||
// Step 4: 对于 messages 路径,进行严格验证:
|
// Step 4: 对于 messages 路径,进行严格验证:
|
||||||
|
|||||||
@ -64,6 +64,7 @@ func TestExtractVersion(t *testing.T) {
|
|||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
|
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
|
||||||
|
{"claude-code/2.1.88", "2.1.88"},
|
||||||
{"claude-cli/1.0.0", "1.0.0"},
|
{"claude-cli/1.0.0", "1.0.0"},
|
||||||
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
|
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
|
||||||
{"curl/8.0.0", ""}, // 非 Claude CLI
|
{"curl/8.0.0", ""}, // 非 Claude CLI
|
||||||
|
|||||||
151
backend/internal/service/context_compressor.go
Normal file
151
backend/internal/service/context_compressor.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
|
||||||
|
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
|
||||||
|
const defaultContextCompressionMaxTokens = 190_000
|
||||||
|
|
||||||
|
// approxTokens estimates how many tokens a string occupies using the
// "four bytes per token" heuristic, rounding up. The length is measured in
// bytes (len), not runes.
func approxTokens(s string) int {
	byteLen := float64(len(s))
	estimate := math.Ceil(byteLen / 4.0)
	return int(estimate)
}
|
||||||
|
|
||||||
|
// compressMessagesInBody trims the oldest messages from the request body so that the
|
||||||
|
// estimated token count of the messages array fits within maxTokens.
|
||||||
|
// Returns the original body unchanged if no compression is needed or if parsing fails.
|
||||||
|
func compressMessagesInBody(body []byte, maxTokens int) []byte {
|
||||||
|
msgsResult := gjson.GetBytes(body, "messages")
|
||||||
|
if !msgsResult.Exists() || !msgsResult.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal to a typed slice for processing.
|
||||||
|
var messages []map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, changed := compressMessages(messages, maxTokens)
|
||||||
|
if !changed {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
newMsgs, err := json.Marshal(compressed)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressMessages removes the oldest messages from the front of msgs until the
|
||||||
|
// estimated total token count is at or below maxTokens.
|
||||||
|
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
|
||||||
|
// to avoid orphaned tool_result blocks.
|
||||||
|
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
|
||||||
|
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate total tokens.
|
||||||
|
totalTokens := 0
|
||||||
|
for _, m := range msgs {
|
||||||
|
totalTokens += msgTokens(m)
|
||||||
|
}
|
||||||
|
if totalTokens <= maxTokens {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
|
||||||
|
type unit struct {
|
||||||
|
startIdx int
|
||||||
|
endIdx int // exclusive
|
||||||
|
tokens int
|
||||||
|
}
|
||||||
|
units := make([]unit, 0, len(msgs))
|
||||||
|
i := 0
|
||||||
|
for i < len(msgs) {
|
||||||
|
toks := msgTokens(msgs[i])
|
||||||
|
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
|
||||||
|
toks += msgTokens(msgs[i+1])
|
||||||
|
units = append(units, unit{i, i + 2, toks})
|
||||||
|
i += 2
|
||||||
|
} else {
|
||||||
|
units = append(units, unit{i, i + 1, toks})
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove units from the front until we are within budget.
|
||||||
|
// Always keep at least the last unit so we never send an empty messages array.
|
||||||
|
removeCount := 0
|
||||||
|
for removeCount < len(units)-1 && totalTokens > maxTokens {
|
||||||
|
totalTokens -= units[removeCount].tokens
|
||||||
|
removeCount++
|
||||||
|
}
|
||||||
|
if removeCount == 0 {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cutIdx := units[removeCount].startIdx
|
||||||
|
return msgs[cutIdx:], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// msgTokens estimates token count for a single message using the chars/4 heuristic.
|
||||||
|
func msgTokens(msg map[string]any) int {
|
||||||
|
b, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return approxTokens(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAssistantWithToolUse returns true if msg is an assistant message whose content
|
||||||
|
// contains at least one block with "type": "tool_use".
|
||||||
|
func isAssistantWithToolUse(msg map[string]any) bool {
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if role != "assistant" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return contentContainsType(msg["content"], "tool_use")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUserWithToolResult returns true if msg is a user message whose content
|
||||||
|
// contains at least one block with "type": "tool_result".
|
||||||
|
func isUserWithToolResult(msg map[string]any) bool {
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if role != "user" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return contentContainsType(msg["content"], "tool_result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// contentContainsType reports whether content is a []any that contains a
// map[string]any element whose "type" value equals blockType. Content of any
// other shape (for example a plain string) yields false; elements that are not
// objects are ignored.
func contentContainsType(content any, blockType string) bool {
	items, isList := content.([]any)
	if !isList {
		return false
	}
	for idx := range items {
		fields, isObject := items[idx].(map[string]any)
		if !isObject {
			continue
		}
		// A missing or non-string "type" is read as the empty string.
		kind, _ := fields["type"].(string)
		if kind == blockType {
			return true
		}
	}
	return false
}
|
||||||
195
backend/internal/service/context_compressor_test.go
Normal file
195
backend/internal/service/context_compressor_test.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
// makeMsg builds a message whose content is a plain text string.
func makeMsg(role, text string) map[string]any {
	msg := make(map[string]any, 2)
	msg["role"] = role
	msg["content"] = text
	return msg
}
|
||||||
|
|
||||||
|
// makeToolUseMsg builds an assistant message with a single tool_use block for
// a tool named "search", carrying the given block id and an empty input.
func makeToolUseMsg(id string) map[string]any {
	toolUse := map[string]any{
		"type":  "tool_use",
		"id":    id,
		"name":  "search",
		"input": map[string]any{},
	}
	return map[string]any{
		"role":    "assistant",
		"content": []any{toolUse},
	}
}
|
||||||
|
|
||||||
|
// makeToolResultMsg builds a user message with a single tool_result block that
// answers the tool_use identified by toolUseID.
func makeToolResultMsg(toolUseID string) map[string]any {
	toolResult := map[string]any{
		"type":        "tool_result",
		"tool_use_id": toolUseID,
		"content":     "result text",
	}
	return map[string]any{
		"role":    "user",
		"content": []any{toolResult},
	}
}
|
||||||
|
|
||||||
|
// toAnySlice copies msgs into a []any of the same length, element by element.
func toAnySlice(msgs []map[string]any) []any {
	converted := make([]any, 0, len(msgs))
	for _, msg := range msgs {
		converted = append(converted, msg)
	}
	return converted
}
|
||||||
|
|
||||||
|
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
|
||||||
|
t.Helper()
|
||||||
|
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// tests
|
||||||
|
|
||||||
|
func TestApproxTokens(t *testing.T) {
|
||||||
|
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
|
||||||
|
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
|
||||||
|
assert.Equal(t, 0, approxTokens(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "hi"),
|
||||||
|
makeMsg("assistant", "hello"),
|
||||||
|
}
|
||||||
|
result, changed := compressMessages(msgs, 100_000)
|
||||||
|
assert.False(t, changed)
|
||||||
|
assert.Len(t, result, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
|
||||||
|
// 10 messages, each large enough to be over a tight budget when combined.
|
||||||
|
msgs := make([]map[string]any, 10)
|
||||||
|
for i := range msgs {
|
||||||
|
role := "user"
|
||||||
|
if i%2 == 1 {
|
||||||
|
role = "assistant"
|
||||||
|
}
|
||||||
|
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force compression by using a very small token budget.
|
||||||
|
result, changed := compressMessages(msgs, 1)
|
||||||
|
assert.True(t, changed)
|
||||||
|
// Must keep at least one message (the last).
|
||||||
|
assert.GreaterOrEqual(t, len(result), 1)
|
||||||
|
// The remaining messages should be from the tail (newest).
|
||||||
|
lastOrig := msgs[len(msgs)-1]["content"]
|
||||||
|
lastResult := result[len(result)-1]["content"]
|
||||||
|
assert.Equal(t, lastOrig, lastResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
|
||||||
|
// Messages: user → assistant+tool_use → user+tool_result → assistant
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "start"),
|
||||||
|
makeToolUseMsg("tool-1"),
|
||||||
|
makeToolResultMsg("tool-1"),
|
||||||
|
makeMsg("assistant", "done"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget that forces removal of the first non-paired message but keeps the tool pair.
|
||||||
|
// Estimate total tokens and set budget to force removing only "start" but not the pair.
|
||||||
|
total := 0
|
||||||
|
for _, m := range msgs {
|
||||||
|
total += msgTokens(m)
|
||||||
|
}
|
||||||
|
// Budget: remove "start" but keep tool pair + "done".
|
||||||
|
startTokens := msgTokens(msgs[0])
|
||||||
|
budget := total - startTokens
|
||||||
|
|
||||||
|
result, changed := compressMessages(msgs, budget)
|
||||||
|
assert.True(t, changed)
|
||||||
|
|
||||||
|
// tool_use and tool_result should both be present or both absent.
|
||||||
|
hasToolUse := false
|
||||||
|
hasToolResult := false
|
||||||
|
for _, m := range result {
|
||||||
|
if isAssistantWithToolUse(m) {
|
||||||
|
hasToolUse = true
|
||||||
|
}
|
||||||
|
if isUserWithToolResult(m) {
|
||||||
|
hasToolResult = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
|
||||||
|
// Budget forces removal of the tool pair.
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "start"),
|
||||||
|
makeToolUseMsg("tool-1"),
|
||||||
|
makeToolResultMsg("tool-1"),
|
||||||
|
makeMsg("assistant", "final answer after tool use"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget: only keep the last "assistant" message.
|
||||||
|
lastTokens := msgTokens(msgs[len(msgs)-1])
|
||||||
|
|
||||||
|
result, changed := compressMessages(msgs, lastTokens)
|
||||||
|
assert.True(t, changed)
|
||||||
|
|
||||||
|
// Neither tool_use nor tool_result should remain.
|
||||||
|
for _, m := range result {
|
||||||
|
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
|
||||||
|
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
|
||||||
|
result := compressMessagesInBody(body, 1)
|
||||||
|
assert.Equal(t, body, result, "body without messages should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
|
||||||
|
msgs := []map[string]any{makeMsg("user", "hi")}
|
||||||
|
body := bodyWithMessages(t, msgs)
|
||||||
|
result := compressMessagesInBody(body, 100_000)
|
||||||
|
assert.Equal(t, body, result, "body under budget should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
|
||||||
|
msgs := make([]map[string]any, 20)
|
||||||
|
for i := range msgs {
|
||||||
|
role := "user"
|
||||||
|
if i%2 == 1 {
|
||||||
|
role = "assistant"
|
||||||
|
}
|
||||||
|
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
|
||||||
|
}
|
||||||
|
body := bodyWithMessages(t, msgs)
|
||||||
|
|
||||||
|
// Force significant compression.
|
||||||
|
result := compressMessagesInBody(body, 50)
|
||||||
|
assert.Less(t, len(result), len(body), "compressed body should be smaller")
|
||||||
|
|
||||||
|
// Resulting body should still be valid JSON with a messages array.
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||||
|
resultMsgs, ok := parsed["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
|
||||||
|
}
|
||||||
@ -689,6 +689,41 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
|
|||||||
require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
sessionID := "12345678-1234-1234-1234-123456789abc"
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-3-7-sonnet-20250219",
|
||||||
|
"metadata": map[string]any{
|
||||||
|
"user_id": FormatMetadataUserID(
|
||||||
|
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||||
|
"",
|
||||||
|
sessionID,
|
||||||
|
claude.DefaultCLIProductVersion,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-3-7-sonnet-20250219", false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, sessionID, getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -335,8 +335,9 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
|||||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var (
|
var (
|
||||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||||
|
claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`)
|
||||||
|
|
||||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||||
@ -569,7 +570,8 @@ type GatewayService struct {
|
|||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
|
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
|
||||||
userGroupRateResolver *userGroupRateResolver
|
userGroupRateResolver *userGroupRateResolver
|
||||||
userGroupRateCache *gocache.Cache
|
userGroupRateCache *gocache.Cache
|
||||||
userGroupRateSF singleflight.Group
|
userGroupRateSF singleflight.Group
|
||||||
@ -614,6 +616,7 @@ func NewGatewayService(
|
|||||||
channelService *ChannelService,
|
channelService *ChannelService,
|
||||||
resolver *ModelPricingResolver,
|
resolver *ModelPricingResolver,
|
||||||
balanceNotifyService *BalanceNotifyService,
|
balanceNotifyService *BalanceNotifyService,
|
||||||
|
rpmTokenBucketSvc *RPMTokenBucketService,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||||
@ -640,6 +643,7 @@ func NewGatewayService(
|
|||||||
claudeTokenProvider: claudeTokenProvider,
|
claudeTokenProvider: claudeTokenProvider,
|
||||||
sessionLimitCache: sessionLimitCache,
|
sessionLimitCache: sessionLimitCache,
|
||||||
rpmCache: rpmCache,
|
rpmCache: rpmCache,
|
||||||
|
rpmTokenBucket: rpmTokenBucketSvc,
|
||||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||||
@ -1374,6 +1378,9 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||||
|
if platform == PlatformAnthropic && s.enableTierFallbackChain() {
|
||||||
|
return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||||
|
}
|
||||||
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -2558,6 +2565,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
|
||||||
|
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
|
||||||
|
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||||
|
if s.rpmTokenBucket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
|
||||||
|
}
|
||||||
|
|
||||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||||
@ -3765,6 +3781,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
|||||||
return ParseMetadataUserID(metadataUserID) != nil
|
return ParseMetadataUserID(metadataUserID) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||||
|
if IsClaudeCodeClient(ctx) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if parsed == nil || c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
||||||
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
|
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
|
||||||
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
|
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
|
||||||
@ -4323,6 +4349,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||||
body = StripEmptyTextBlocks(body)
|
body = StripEmptyTextBlocks(body)
|
||||||
|
|
||||||
|
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
|
||||||
|
if account.IsContextCompressionEnabled() {
|
||||||
|
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
|
||||||
|
body = compressMessagesInBody(body, maxTok)
|
||||||
|
}
|
||||||
|
|
||||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||||
setOpsUpstreamRequestBody(c, body)
|
setOpsUpstreamRequestBody(c, body)
|
||||||
|
|
||||||
@ -5923,15 +5955,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
// X-Claude-Code-Session-Id 头处理:
|
// X-Claude-Code-Session-Id 头处理:
|
||||||
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id
|
// Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
|
||||||
// 2. 客户端未提供(mimic 模式)→ 生成确定性 per-account session UUID
|
// 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
|
||||||
// 真实 CLI 每个请求都携带此 header(per-process UUID)。
|
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
|
||||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
|
||||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -8969,12 +8997,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
// Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
|
||||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
// 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
|
||||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
133
backend/internal/service/gateway_tier_fallback.go
Normal file
133
backend/internal/service/gateway_tier_fallback.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// accountTierLevel maps an account type to a scheduling tier:
|
||||||
|
//
|
||||||
|
// 0 = subscription (OAuth / SetupToken) — tried first
|
||||||
|
// 1 = API Key — first fallback
|
||||||
|
// 2 = Bedrock — last resort
|
||||||
|
//
|
||||||
|
// Accounts with an unknown type fall into tier 0 so they participate in the
|
||||||
|
// primary selection and do not vanish silently.
|
||||||
|
func accountTierLevel(account *Account) int {
|
||||||
|
if account == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch account.Type {
|
||||||
|
case AccountTypeAPIKey:
|
||||||
|
return 1
|
||||||
|
case AccountTypeBedrock:
|
||||||
|
return 2
|
||||||
|
default: // OAuth, SetupToken, or unknown
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// enableTierFallbackChain reports whether the cross-tier fallback chain is
|
||||||
|
// enabled in config (default false).
|
||||||
|
func (s *GatewayService) enableTierFallbackChain() bool {
|
||||||
|
return s != nil && s.cfg != nil && s.cfg.Gateway.Scheduling.EnableTierFallbackChain
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectAccountWithTierFallback tries Anthropic accounts in tier order:
|
||||||
|
// tier 0 (OAuth/SetupToken subscription) → tier 1 (API Key) → tier 2 (Bedrock).
|
||||||
|
//
|
||||||
|
// Sticky sessions are honored within the chain: if the session-bound account is
|
||||||
|
// in a tier that still has capacity it is returned immediately; otherwise the
|
||||||
|
// session binding is cleared and the chain proceeds from tier 0.
|
||||||
|
func (s *GatewayService) selectAccountWithTierFallback(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
sessionHash string,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
) (*Account, error) {
|
||||||
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAnthropic, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||||
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
|
// Build per-tier candidate lists (pointers into `accounts`).
|
||||||
|
const numTiers = 3
|
||||||
|
tierCandidates := [numTiers][]*Account{}
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
if acc.Platform != PlatformAnthropic {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForQuota(acc) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tier := accountTierLevel(acc)
|
||||||
|
if tier < numTiers {
|
||||||
|
tierCandidates[tier] = append(tierCandidates[tier], acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := s.schedulingConfig()
|
||||||
|
selectionMode := cfg.FallbackSelectionMode
|
||||||
|
|
||||||
|
// Check sticky session: if the bound account is a valid candidate, use it.
|
||||||
|
if sessionHash != "" && s.cache != nil {
|
||||||
|
accountID, cacheErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
if cacheErr == nil && accountID > 0 {
|
||||||
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
|
for tier := 0; tier < numTiers; tier++ {
|
||||||
|
for _, acc := range tierCandidates[tier] {
|
||||||
|
if acc.ID != accountID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if shouldClearStickySession(acc, requestedModel) {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s.isAccountSchedulableForWindowCost(ctx, acc, true) &&
|
||||||
|
s.isAccountSchedulableForRPM(ctx, acc, true) {
|
||||||
|
return acc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each tier in order.
|
||||||
|
for tier := 0; tier < numTiers; tier++ {
|
||||||
|
candidates := tierCandidates[tier]
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.sortCandidatesForFallback(candidates, false, selectionMode)
|
||||||
|
result, acquired, _ := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, false)
|
||||||
|
if acquired && result != nil {
|
||||||
|
return result.Account, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no available accounts in any tier")
|
||||||
|
}
|
||||||
138
backend/internal/service/gateway_tier_fallback_test.go
Normal file
138
backend/internal/service/gateway_tier_fallback_test.go
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTierLevel(t *testing.T) {
|
||||||
|
require.Equal(t, 0, accountTierLevel(nil))
|
||||||
|
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeOAuth}))
|
||||||
|
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeSetupToken}))
|
||||||
|
require.Equal(t, 0, accountTierLevel(&Account{Type: "unknown"}))
|
||||||
|
require.Equal(t, 1, accountTierLevel(&Account{Type: AccountTypeAPIKey}))
|
||||||
|
require.Equal(t, 2, accountTierLevel(&Account{Type: AccountTypeBedrock}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_EnableTierFallbackChain(t *testing.T) {
|
||||||
|
require.False(t, (*GatewayService)(nil).enableTierFallbackChain())
|
||||||
|
require.False(t, (&GatewayService{}).enableTierFallbackChain())
|
||||||
|
|
||||||
|
cfgOff := &config.Config{}
|
||||||
|
cfgOff.Gateway.Scheduling.EnableTierFallbackChain = false
|
||||||
|
require.False(t, (&GatewayService{cfg: cfgOff}).enableTierFallbackChain())
|
||||||
|
|
||||||
|
cfgOn := &config.Config{}
|
||||||
|
cfgOn.Gateway.Scheduling.EnableTierFallbackChain = true
|
||||||
|
require.True(t, (&GatewayService{cfg: cfgOn}).enableTierFallbackChain())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription verifies
|
||||||
|
// that when both OAuth (subscription) and APIKey accounts are available, the
|
||||||
|
// tier-0 OAuth account is always selected first even if APIKey has higher priority.
|
||||||
|
func TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
oauthAcc := Account{ID: 91001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Priority: 5}
|
||||||
|
apiKeyAcc := Account{ID: 91002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Priority: 0}
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||||
|
accountsByID: map[int64]*Account{91001: &oauthAcc, 91002: &apiKeyAcc},
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(91001), acc.ID, "OAuth (tier-0) account should be preferred over APIKey (tier-1)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey verifies
|
||||||
|
// that when the subscription tier has no schedulable accounts, the fallback
|
||||||
|
// selects an API Key account.
|
||||||
|
func TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||||
|
oauthAcc := Account{ID: 92001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||||
|
apiKeyAcc := Account{ID: 92002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||||
|
accountsByID: map[int64]*Account{92001: &oauthAcc, 92002: &apiKeyAcc},
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(92002), acc.ID, "Should fall back to APIKey when OAuth is rate-limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts ensures
|
||||||
|
// excluded IDs are respected across all tiers.
|
||||||
|
func TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
oauthAcc := Account{ID: 93001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}
|
||||||
|
apiKeyAcc := Account{ID: 93002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||||
|
accountsByID: map[int64]*Account{93001: &oauthAcc, 93002: &apiKeyAcc},
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||||
|
|
||||||
|
excluded := map[int64]struct{}{93001: {}}
|
||||||
|
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", excluded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(93002), acc.ID, "Excluded OAuth account should cause APIKey fallback")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithTierFallback_NoAccounts verifies that
|
||||||
|
// an error is returned when all tiers are empty.
|
||||||
|
func TestGatewayService_SelectAccountWithTierFallback_NoAccounts(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{accounts: nil, accountsByID: map[int64]*Account{}}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort verifies
|
||||||
|
// that Bedrock accounts are only used when subscription and API Key tiers are exhausted.
|
||||||
|
func TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||||
|
oauthAcc := Account{ID: 94001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||||
|
apiKeyAcc := Account{ID: 94002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||||
|
bedrockAcc := Account{ID: 94003, Platform: PlatformAnthropic, Type: AccountTypeBedrock, Status: StatusActive, Schedulable: true}
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForPlatform{
|
||||||
|
accounts: []Account{oauthAcc, apiKeyAcc, bedrockAcc},
|
||||||
|
accountsByID: map[int64]*Account{94001: &oauthAcc, 94002: &apiKeyAcc, 94003: &bedrockAcc},
|
||||||
|
}
|
||||||
|
cache := &mockGatewayCacheForPlatform{}
|
||||||
|
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||||
|
|
||||||
|
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(94003), acc.ID, "Bedrock should be selected as last resort")
|
||||||
|
}
|
||||||
119
backend/internal/service/health_service.go
Normal file
119
backend/internal/service/health_service.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
// Package service - HealthService 提供 liveness 与 readiness 探针。
|
||||||
|
//
|
||||||
|
// 设计动机:原有 /health 端点既被 docker-compose healthcheck 使用,又被
|
||||||
|
// dashboard 的 ops_health_score 复用——后者会触发 DB/Redis 等重操作,
|
||||||
|
// 导致探活流量污染监控指标。本服务把两类语义拆开:
|
||||||
|
// - Liveness : 仅证明进程存活(无外部依赖检查)。
|
||||||
|
// - Readiness : 检查 DB + Redis 连通,作为是否可接收流量的判断。
|
||||||
|
//
|
||||||
|
// dashboard 维度的"业务健康分"仍由 ops_health_score 计算,与本服务无关。
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"time"

	"github.com/redis/go-redis/v9"
)
|
||||||
|
|
||||||
|
// 探针默认超时。Readiness 探针需要快速失败,避免堆积。
|
||||||
|
const (
|
||||||
|
defaultReadinessTimeout = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadinessReport describes the state of every probed dependency so callers
// can expose the details for troubleshooting.
//
// OK is false as soon as one component fails. Details is keyed by component
// name. Elapsed is the wall-clock time of the whole readiness run.
type ReadinessReport struct {
	OK      bool                       `json:"ok"`
	Details map[string]ComponentStatus `json:"details"`
	Elapsed time.Duration              `json:"elapsed_ms"`
}

// MarshalJSON encodes the report with "elapsed_ms" expressed in whole
// milliseconds. Without it, encoding/json would write the raw time.Duration,
// which is a nanosecond count, under a key that promises milliseconds.
// Decoding is not customised: this type is meant to be produced, not parsed.
func (r ReadinessReport) MarshalJSON() ([]byte, error) {
	type wireReport struct {
		OK        bool                       `json:"ok"`
		Details   map[string]ComponentStatus `json:"details"`
		ElapsedMS int64                      `json:"elapsed_ms"`
	}
	return json.Marshal(wireReport{
		OK:        r.OK,
		Details:   r.Details,
		ElapsedMS: r.Elapsed.Milliseconds(),
	})
}

// ComponentStatus is the state of a single dependency. Error is empty when OK
// is true. Elapsed is the probe duration rendered as a string.
type ComponentStatus struct {
	OK      bool   `json:"ok"`
	Error   string `json:"error,omitempty"`
	Elapsed string `json:"elapsed,omitempty"`
}
|
||||||
|
|
||||||
|
// HealthService 提供 liveness/readiness 探针。
|
||||||
|
// 字段都允许为 nil:缺失的依赖在 readiness 中自动跳过,便于测试和分阶段启用。
|
||||||
|
type HealthService struct {
|
||||||
|
db *sql.DB
|
||||||
|
rdb *redis.Client
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHealthService 构造函数。timeout<=0 时使用默认值。
|
||||||
|
func NewHealthService(db *sql.DB, rdb *redis.Client) *HealthService {
|
||||||
|
return &HealthService{
|
||||||
|
db: db,
|
||||||
|
rdb: rdb,
|
||||||
|
timeout: defaultReadinessTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Liveness 仅返回 nil。任何调用方能拿到这个返回值就说明进程在响应请求。
|
||||||
|
// 保持无副作用、零依赖,便于 K8s livenessProbe 高频调用。
|
||||||
|
func (s *HealthService) Liveness() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Readiness 检查所有外部依赖。任一失败则整体 OK=false。
|
||||||
|
// 单个依赖的 ctx 超时由 timeout 控制,独立计时不互相阻塞。
|
||||||
|
func (s *HealthService) Readiness(ctx context.Context) ReadinessReport {
|
||||||
|
start := time.Now()
|
||||||
|
report := ReadinessReport{
|
||||||
|
OK: true,
|
||||||
|
Details: make(map[string]ComponentStatus, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.db != nil {
|
||||||
|
report.Details["database"] = s.checkDB(ctx)
|
||||||
|
if !report.Details["database"].OK {
|
||||||
|
report.OK = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.rdb != nil {
|
||||||
|
report.Details["redis"] = s.checkRedis(ctx)
|
||||||
|
if !report.Details["redis"].OK {
|
||||||
|
report.OK = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
report.Elapsed = time.Since(start)
|
||||||
|
return report
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HealthService) checkDB(parent context.Context) ComponentStatus {
|
||||||
|
ctx, cancel := context.WithTimeout(parent, s.timeout)
|
||||||
|
defer cancel()
|
||||||
|
start := time.Now()
|
||||||
|
err := s.db.PingContext(ctx)
|
||||||
|
status := ComponentStatus{Elapsed: time.Since(start).String()}
|
||||||
|
if err != nil {
|
||||||
|
status.Error = err.Error()
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
status.OK = true
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HealthService) checkRedis(parent context.Context) ComponentStatus {
|
||||||
|
ctx, cancel := context.WithTimeout(parent, s.timeout)
|
||||||
|
defer cancel()
|
||||||
|
start := time.Now()
|
||||||
|
pong, err := s.rdb.Ping(ctx).Result()
|
||||||
|
status := ComponentStatus{Elapsed: time.Since(start).String()}
|
||||||
|
if err != nil {
|
||||||
|
status.Error = err.Error()
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
if pong != "PONG" {
|
||||||
|
status.Error = errors.New("unexpected redis ping response: " + pong).Error()
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
status.OK = true
|
||||||
|
return status
|
||||||
|
}
|
||||||
93
backend/internal/service/health_service_test.go
Normal file
93
backend/internal/service/health_service_test.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthService_Liveness_AlwaysOK(t *testing.T) {
|
||||||
|
s := NewHealthService(nil, nil)
|
||||||
|
require.NoError(t, s.Liveness())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthService_Readiness_AllNilReturnsOK(t *testing.T) {
|
||||||
|
// 当所有依赖都为 nil 时(早期启动或 unit test),readiness 应直接 OK。
|
||||||
|
s := NewHealthService(nil, nil)
|
||||||
|
report := s.Readiness(context.Background())
|
||||||
|
require.True(t, report.OK)
|
||||||
|
require.Empty(t, report.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthService_Readiness_DBPingFails(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectPing().WillReturnError(errors.New("connection refused"))
|
||||||
|
|
||||||
|
s := NewHealthService(db, nil)
|
||||||
|
report := s.Readiness(context.Background())
|
||||||
|
require.False(t, report.OK)
|
||||||
|
require.Contains(t, report.Details, "database")
|
||||||
|
require.False(t, report.Details["database"].OK)
|
||||||
|
require.Contains(t, report.Details["database"].Error, "connection refused")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthService_Readiness_DBOK(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectPing()
|
||||||
|
|
||||||
|
s := NewHealthService(db, nil)
|
||||||
|
report := s.Readiness(context.Background())
|
||||||
|
require.True(t, report.OK)
|
||||||
|
require.True(t, report.Details["database"].OK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthService_Readiness_RedisFails(t *testing.T) {
|
||||||
|
// 指向一个不可达端口让 redis ping 立刻失败。
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 200 * time.Millisecond,
|
||||||
|
ReadTimeout: 200 * time.Millisecond,
|
||||||
|
})
|
||||||
|
defer rdb.Close()
|
||||||
|
|
||||||
|
s := NewHealthService(nil, rdb)
|
||||||
|
s.timeout = 500 * time.Millisecond
|
||||||
|
report := s.Readiness(context.Background())
|
||||||
|
require.False(t, report.OK)
|
||||||
|
require.Contains(t, report.Details, "redis")
|
||||||
|
require.False(t, report.Details["redis"].OK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthService_Readiness_PerComponentTimeout(t *testing.T) {
|
||||||
|
// 验证 readiness 在超时时不会无限挂住。
|
||||||
|
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
mock.ExpectPing().WillDelayFor(2 * time.Second)
|
||||||
|
|
||||||
|
s := NewHealthService(db, nil)
|
||||||
|
s.timeout = 100 * time.Millisecond
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
report := s.Readiness(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Less(t, elapsed, 1*time.Second, "readiness should respect per-component timeout")
|
||||||
|
require.False(t, report.OK)
|
||||||
|
require.NotEmpty(t, report.Details["database"].Error, "timeout should propagate as an error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keeps the database/sql import referenced; nothing else in this file names the package directly.
|
||||||
|
var _ = sql.ErrNoRows
|
||||||
@ -8,6 +8,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||||
@ -30,12 +32,22 @@ type OAuthRefreshResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||||
// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑
|
// 封装分布式锁、进程内去重(singleflight)、DB 重读、已刷新检查、竞争恢复等通用逻辑
|
||||||
|
//
|
||||||
|
// 双层去重设计:
|
||||||
|
// 1. 进程内 singleflight:合并同一 cacheKey 的并发调用(避免 100 个 goroutine
|
||||||
|
// 都去抢同一把分布式锁、都重读一次 DB)。
|
||||||
|
// 2. 跨进程分布式锁(Redis):保证集群范围内只有一个 worker 真正发起 OAuth
|
||||||
|
// 刷新请求。
|
||||||
|
//
|
||||||
|
// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT;跨进程锁仍是必需的,
|
||||||
|
// singleflight 解决不了多 pod 同时刷新。
|
||||||
type OAuthRefreshAPI struct {
|
type OAuthRefreshAPI struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache GeminiTokenCache // 可选,nil = 无分布式锁
|
tokenCache GeminiTokenCache // 可选,nil = 无分布式锁
|
||||||
lockTTL time.Duration
|
lockTTL time.Duration
|
||||||
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
|
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
|
||||||
|
sf singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||||
@ -63,15 +75,19 @@ func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
|
|||||||
return mu
|
return mu
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。
|
||||||
|
//
|
||||||
|
// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者"
|
||||||
|
// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。
|
||||||
//
|
//
|
||||||
// 流程:
|
// 流程:
|
||||||
// 1. 获取分布式锁
|
// 1. singleflight 合并同 cacheKey 并发调用
|
||||||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
// 2. 获取分布式锁(跨进程)
|
||||||
// 3. 二次检查是否仍需刷新
|
// 3. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
// 4. 二次检查是否仍需刷新
|
||||||
// 5. 设置 _token_version + 更新 DB
|
// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||||
// 6. 释放锁
|
// 6. 设置 _token_version + 更新 DB
|
||||||
|
// 7. 释放锁
|
||||||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
account *Account,
|
account *Account,
|
||||||
@ -80,11 +96,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
|||||||
) (*OAuthRefreshResult, error) {
|
) (*OAuthRefreshResult, error) {
|
||||||
cacheKey := executor.CacheKey(account)
|
cacheKey := executor.CacheKey(account)
|
||||||
|
|
||||||
// 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争)
|
// singleflight key 同时区分 cacheKey 和 refreshWindow:
|
||||||
localMu := api.getLocalLock(cacheKey)
|
// 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh,
|
||||||
localMu.Lock()
|
// 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。
|
||||||
defer localMu.Unlock()
|
sfKey := cacheKey + "|" + refreshWindow.String()
|
||||||
|
|
||||||
|
v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) {
|
||||||
|
return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, _ := v.(*OAuthRefreshResult)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。
|
||||||
|
// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。
|
||||||
|
func (api *OAuthRefreshAPI) refreshOnce(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
executor OAuthRefreshExecutor,
|
||||||
|
refreshWindow time.Duration,
|
||||||
|
cacheKey string,
|
||||||
|
) (*OAuthRefreshResult, error) {
|
||||||
// 1. 获取分布式锁
|
// 1. 获取分布式锁
|
||||||
lockAcquired := false
|
lockAcquired := false
|
||||||
if api.tokenCache != nil {
|
if api.tokenCache != nil {
|
||||||
|
|||||||
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。
|
||||||
|
type blockingExecutor struct {
|
||||||
|
refreshAPIExecutorStub
|
||||||
|
release chan struct{}
|
||||||
|
concurrent int32 // 当前正在 Refresh 的 goroutine 数
|
||||||
|
maxObserved int32 // 观察到的最大并发数
|
||||||
|
calls int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||||
|
atomic.AddInt32(&e.calls, 1)
|
||||||
|
c := atomic.AddInt32(&e.concurrent, 1)
|
||||||
|
for {
|
||||||
|
old := atomic.LoadInt32(&e.maxObserved)
|
||||||
|
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer atomic.AddInt32(&e.concurrent, -1)
|
||||||
|
|
||||||
|
<-e.release
|
||||||
|
return e.credentials, e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||||
|
// 同一 cacheKey 同时进入 N 个 goroutine,应只触发 1 次 executor.Refresh。
|
||||||
|
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
|
||||||
|
exec := &blockingExecutor{
|
||||||
|
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "new"},
|
||||||
|
},
|
||||||
|
release: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
|
||||||
|
const callers = 20
|
||||||
|
results := make([]*OAuthRefreshResult, callers)
|
||||||
|
errs := make([]error, callers)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(callers)
|
||||||
|
|
||||||
|
for i := 0; i < callers; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||||
|
results[i] = r
|
||||||
|
errs[i] = err
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")
|
||||||
|
|
||||||
|
close(exec.release)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")
|
||||||
|
|
||||||
|
// 所有 caller 应拿到等价结果(不必同实例,singleflight Shared 标志会让多个 caller 共享)。
|
||||||
|
for i := 0; i < callers; i++ {
|
||||||
|
require.NoError(t, errs[i])
|
||||||
|
require.NotNil(t, results[i])
|
||||||
|
require.True(t, results[i].Refreshed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
|
||||||
|
// 不同账号有不同 cacheKey,应能并行刷新而非互相阻塞。
|
||||||
|
repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
|
||||||
|
exec := &blockingExecutor{
|
||||||
|
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "new"},
|
||||||
|
},
|
||||||
|
release: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
platform := "p" + string(rune('a'+i))
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")
|
||||||
|
|
||||||
|
close(exec.release)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
|
||||||
|
// 同 cacheKey 但不同 refreshWindow(前台短窗口 vs 后台长窗口)应分开判断
|
||||||
|
// NeedsRefresh,避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。
|
||||||
|
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||||
|
cache := &refreshAPICacheStub{lockResult: true}
|
||||||
|
|
||||||
|
exec := &blockingExecutor{
|
||||||
|
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||||
|
needsRefresh: true,
|
||||||
|
credentials: map[string]any{"access_token": "new"},
|
||||||
|
},
|
||||||
|
release: make(chan struct{}),
|
||||||
|
}
|
||||||
|
api := NewOAuthRefreshAPI(repo, cache)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour)
|
||||||
|
}()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")
|
||||||
|
|
||||||
|
close(exec.release)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@ -730,12 +730,18 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||||
}
|
}
|
||||||
|
quotaFactor := item.account.GetQuotaRemainingFraction()
|
||||||
|
|
||||||
item.score = weights.Priority*priorityFactor +
|
item.score = weights.Priority*priorityFactor +
|
||||||
weights.Load*loadFactor +
|
weights.Load*loadFactor +
|
||||||
weights.Queue*queueFactor +
|
weights.Queue*queueFactor +
|
||||||
weights.ErrorRate*errorFactor +
|
weights.ErrorRate*errorFactor +
|
||||||
weights.TTFT*ttftFactor
|
weights.TTFT*ttftFactor +
|
||||||
|
weights.Quota*quotaFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.service.openAIWSP2CEnabled() {
|
||||||
|
return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1193,6 +1199,7 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
|
|||||||
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
|
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
|
||||||
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
|
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
|
||||||
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
|
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
|
||||||
|
Quota: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||||
@ -1201,15 +1208,21 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
|
|||||||
Queue: 0.7,
|
Queue: 0.7,
|
||||||
ErrorRate: 0.8,
|
ErrorRate: 0.8,
|
||||||
TTFT: 0.5,
|
TTFT: 0.5,
|
||||||
|
Quota: 0.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) openAIWSP2CEnabled() bool {
|
||||||
|
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EnableP2CScheduling
|
||||||
|
}
|
||||||
|
|
||||||
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
|
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
|
||||||
Priority float64
|
Priority float64
|
||||||
Load float64
|
Load float64
|
||||||
Queue float64
|
Queue float64
|
||||||
ErrorRate float64
|
ErrorRate float64
|
||||||
TTFT float64
|
TTFT float64
|
||||||
|
Quota float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func clamp01(value float64) float64 {
|
func clamp01(value float64) float64 {
|
||||||
@ -1223,6 +1236,94 @@ func clamp01(value float64) float64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectByPowerOfTwo implements Power-of-Two-Choices (P2C): sample 2 random
|
||||||
|
// candidates and attempt the better-scored one first, then the other.
|
||||||
|
// This gives O(1) selection with load distribution comparable to top-K when N is large.
|
||||||
|
func (s *defaultOpenAIAccountScheduler) selectByPowerOfTwo(
|
||||||
|
ctx context.Context,
|
||||||
|
req OpenAIAccountScheduleRequest,
|
||||||
|
candidates []openAIAccountCandidateScore,
|
||||||
|
loadSkew float64,
|
||||||
|
) (*AccountSelectionResult, int, int, float64, error) {
|
||||||
|
n := len(candidates)
|
||||||
|
if n == 0 {
|
||||||
|
return nil, 0, 0, loadSkew, ErrNoAvailableAccounts
|
||||||
|
}
|
||||||
|
|
||||||
|
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
|
||||||
|
|
||||||
|
// Pick two distinct random indices.
|
||||||
|
idxA := int(rng.nextUint64() % uint64(n))
|
||||||
|
idxB := idxA
|
||||||
|
if n > 1 {
|
||||||
|
for idxB == idxA {
|
||||||
|
idxB = int(rng.nextUint64() % uint64(n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order: better candidate first.
|
||||||
|
first, second := candidates[idxA], candidates[idxB]
|
||||||
|
if isOpenAIAccountCandidateBetter(second, first) {
|
||||||
|
first, second = second, first
|
||||||
|
}
|
||||||
|
|
||||||
|
tryAcquire := func(c openAIAccountCandidateScore) (*AccountSelectionResult, bool, error) {
|
||||||
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
|
||||||
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, req.RequireCompact)
|
||||||
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
result, err := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if result != nil && result.Acquired {
|
||||||
|
if req.SessionHash != "" {
|
||||||
|
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: fresh,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range []openAIAccountCandidateScore{first, second} {
|
||||||
|
result, ok, err := tryAcquire(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, n, 2, loadSkew, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return result, n, 2, loadSkew, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both slots busy — return wait plan on the better candidate.
|
||||||
|
cfg := s.service.schedulingConfig()
|
||||||
|
for _, c := range []openAIAccountCandidateScore{first, second} {
|
||||||
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
|
||||||
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: fresh,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: fresh.ID,
|
||||||
|
MaxConcurrency: fresh.Concurrency,
|
||||||
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
|
},
|
||||||
|
}, n, 2, loadSkew, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, n, 2, loadSkew, ErrNoAvailableAccounts
|
||||||
|
}
|
||||||
|
|
||||||
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
|
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@ -1448,3 +1448,129 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
|
|||||||
func int64PtrForTest(v int64) *int64 {
|
func int64PtrForTest(v int64) *int64 {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetQuotaRemainingFraction(t *testing.T) {
|
||||||
|
// No limit configured → always 1.0 (unlimited)
|
||||||
|
noLimit := &Account{}
|
||||||
|
require.Equal(t, 1.0, noLimit.GetQuotaRemainingFraction())
|
||||||
|
|
||||||
|
// 50% used
|
||||||
|
half := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 50.0}}
|
||||||
|
require.InDelta(t, 0.5, half.GetQuotaRemainingFraction(), 1e-9)
|
||||||
|
|
||||||
|
// Fully exhausted
|
||||||
|
full := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 100.0}}
|
||||||
|
require.Equal(t, 0.0, full.GetQuotaRemainingFraction())
|
||||||
|
|
||||||
|
// Over limit → clamp to 0
|
||||||
|
over := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 150.0}}
|
||||||
|
require.Equal(t, 0.0, over.GetQuotaRemainingFraction())
|
||||||
|
|
||||||
|
// Fresh (0 used)
|
||||||
|
fresh := &Account{Extra: map[string]any{"quota_limit": 200.0, "quota_used": 0.0}}
|
||||||
|
require.Equal(t, 1.0, fresh.GetQuotaRemainingFraction())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_P2CEnabled(t *testing.T) {
|
||||||
|
require.False(t, (*OpenAIGatewayService)(nil).openAIWSP2CEnabled())
|
||||||
|
require.False(t, (&OpenAIGatewayService{}).openAIWSP2CEnabled())
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.OpenAIWS.EnableP2CScheduling = false
|
||||||
|
require.False(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
|
||||||
|
|
||||||
|
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||||
|
require.True(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SchedulerWeights_QuotaField(t *testing.T) {
|
||||||
|
// Default weights: Quota is 0 (disabled by default)
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
weights := svc.openAIWSSchedulerWeights()
|
||||||
|
require.Equal(t, 0.0, weights.Quota)
|
||||||
|
|
||||||
|
// Config-driven quota weight
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota = 0.4
|
||||||
|
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
require.Equal(t, 0.4, svcWithCfg.openAIWSSchedulerWeights().Quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_SingleCandidate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(99001)
|
||||||
|
account := &Account{ID: 71001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
|
||||||
|
snapshotCache := &openAISnapshotCacheStub{
|
||||||
|
snapshotAccounts: []*Account{account},
|
||||||
|
accountsByID: map[int64]*Account{71001: account},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||||
|
cfg.Gateway.OpenAIWS.LBTopK = 5
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
|
||||||
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.Equal(t, int64(71001), selection.Account.ID)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_PicksBetterCandidate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(99002)
|
||||||
|
// Account A has low priority (better), B has high priority (worse).
|
||||||
|
// With P2C enabled and a deterministic seed, we should always get a valid selection.
|
||||||
|
accountA := &Account{ID: 72001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
|
||||||
|
accountB := &Account{ID: 72002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 10}
|
||||||
|
snapshotCache := &openAISnapshotCacheStub{
|
||||||
|
snapshotAccounts: []*Account{accountA, accountB},
|
||||||
|
accountsByID: map[int64]*Account{72001: accountA, 72002: accountB},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*accountA, *accountB}},
|
||||||
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.NotNil(t, selection.Account)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
// Either account is valid; just verify we got a schedulable one.
|
||||||
|
require.True(t, selection.Account.ID == 72001 || selection.Account.ID == 72002)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultOpenAIAccountScheduler_QuotaFactorInfluencesScore(t *testing.T) {
|
||||||
|
// Verify that quota weight affects scoring by checking GetQuotaRemainingFraction is used.
|
||||||
|
// Account with high remaining quota should score higher when quota weight > 0.
|
||||||
|
highQuota := &Account{
|
||||||
|
ID: 73001, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
|
||||||
|
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 10.0}, // 90% remaining
|
||||||
|
}
|
||||||
|
lowQuota := &Account{
|
||||||
|
ID: 73002, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
|
||||||
|
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 90.0}, // 10% remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
require.InDelta(t, 0.9, highQuota.GetQuotaRemainingFraction(), 1e-9)
|
||||||
|
require.InDelta(t, 0.1, lowQuota.GetQuotaRemainingFraction(), 1e-9)
|
||||||
|
|
||||||
|
// With quota weight = 1.0 and all other weights = 0, high-quota account should win.
|
||||||
|
// We verify the score ordering directly using isOpenAIAccountCandidateBetter.
|
||||||
|
highScore := openAIAccountCandidateScore{account: highQuota, score: 0.9}
|
||||||
|
lowScore := openAIAccountCandidateScore{account: lowQuota, score: 0.1}
|
||||||
|
require.True(t, isOpenAIAccountCandidateBetter(highScore, lowScore))
|
||||||
|
require.False(t, isOpenAIAccountCandidateBetter(lowScore, highScore))
|
||||||
|
}
|
||||||
|
|||||||
75
backend/internal/service/request_event_bus.go
Normal file
75
backend/internal/service/request_event_bus.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// requestEventBufSize is the capacity of the buffered channel created for each
// RequestEventBus subscriber.
const requestEventBufSize = 64
|
||||||
|
|
||||||
|
// RequestEvent is published for every gateway dispatch completion and is the
// payload fanned out to RequestEventBus subscribers. The field order is the
// order in which the fields appear in the JSON encoding.
type RequestEvent struct {
	Timestamp time.Time `json:"timestamp"`
	Method    string    `json:"method"`
	Path      string    `json:"path"`
	Model     string    `json:"model"`
	AccountID int64     `json:"account_id"`
	// Status is "success", "error", or "rate_limited".
	Status string `json:"status"`
	// LatencyMS is presumably the dispatch latency in milliseconds (per the
	// field name) — confirm against the publisher.
	LatencyMS int64 `json:"latency_ms"`
}
|
||||||
|
|
||||||
|
// RequestEventBus is a fan-out hub for real-time request events.
|
||||||
|
// Publishers call Publish; subscribers call Subscribe/Unsubscribe.
|
||||||
|
// Each subscriber gets its own buffered channel. If the buffer is full
|
||||||
|
// the event is dropped for that subscriber (non-blocking publish).
|
||||||
|
type RequestEventBus struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
subscribers map[uint64]chan RequestEvent
|
||||||
|
nextID atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestEventBus() *RequestEventBus {
|
||||||
|
return &RequestEventBus{
|
||||||
|
subscribers: make(map[uint64]chan RequestEvent),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe registers a new subscriber and returns its ID and a receive-only channel.
|
||||||
|
func (b *RequestEventBus) Subscribe() (uint64, <-chan RequestEvent) {
|
||||||
|
id := b.nextID.Add(1)
|
||||||
|
ch := make(chan RequestEvent, requestEventBufSize)
|
||||||
|
b.mu.Lock()
|
||||||
|
b.subscribers[id] = ch
|
||||||
|
b.mu.Unlock()
|
||||||
|
return id, ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a subscriber and closes its channel.
|
||||||
|
func (b *RequestEventBus) Unsubscribe(id uint64) {
|
||||||
|
b.mu.Lock()
|
||||||
|
ch, ok := b.subscribers[id]
|
||||||
|
if ok {
|
||||||
|
delete(b.subscribers, id)
|
||||||
|
}
|
||||||
|
b.mu.Unlock()
|
||||||
|
if ok {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish sends an event to all current subscribers without blocking.
|
||||||
|
func (b *RequestEventBus) Publish(e RequestEvent) {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.mu.RLock()
|
||||||
|
defer b.mu.RUnlock()
|
||||||
|
for _, ch := range b.subscribers {
|
||||||
|
select {
|
||||||
|
case ch <- e:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
100
backend/internal/service/request_event_bus_test.go
Normal file
100
backend/internal/service/request_event_bus_test.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestEventBus_PublishToSubscriber(t *testing.T) {
|
||||||
|
bus := NewRequestEventBus()
|
||||||
|
|
||||||
|
id, ch := bus.Subscribe()
|
||||||
|
defer bus.Unsubscribe(id)
|
||||||
|
|
||||||
|
evt := RequestEvent{Model: "claude-3", Status: "success", LatencyMS: 100}
|
||||||
|
bus.Publish(evt)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-ch:
|
||||||
|
assert.Equal(t, evt, got)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEventBus_MultipleSubscribers(t *testing.T) {
|
||||||
|
bus := NewRequestEventBus()
|
||||||
|
|
||||||
|
id1, ch1 := bus.Subscribe()
|
||||||
|
id2, ch2 := bus.Subscribe()
|
||||||
|
defer bus.Unsubscribe(id1)
|
||||||
|
defer bus.Unsubscribe(id2)
|
||||||
|
|
||||||
|
evt := RequestEvent{Model: "claude-3", Status: "error"}
|
||||||
|
bus.Publish(evt)
|
||||||
|
|
||||||
|
for _, ch := range []<-chan RequestEvent{ch1, ch2} {
|
||||||
|
select {
|
||||||
|
case got := <-ch:
|
||||||
|
assert.Equal(t, evt, got)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for event on one subscriber")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEventBus_UnsubscribeClosesChannel(t *testing.T) {
|
||||||
|
bus := NewRequestEventBus()
|
||||||
|
id, ch := bus.Subscribe()
|
||||||
|
|
||||||
|
bus.Unsubscribe(id)
|
||||||
|
|
||||||
|
// Channel should be closed.
|
||||||
|
_, ok := <-ch
|
||||||
|
assert.False(t, ok, "channel should be closed after Unsubscribe")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEventBus_UnsubscribedMissesEvents(t *testing.T) {
|
||||||
|
bus := NewRequestEventBus()
|
||||||
|
id, _ := bus.Subscribe()
|
||||||
|
bus.Unsubscribe(id)
|
||||||
|
|
||||||
|
// Publish after unsubscribe should not panic.
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
bus.Publish(RequestEvent{Model: "test"})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEventBus_DropWhenFull(t *testing.T) {
|
||||||
|
bus := NewRequestEventBus()
|
||||||
|
id, ch := bus.Subscribe()
|
||||||
|
defer bus.Unsubscribe(id)
|
||||||
|
|
||||||
|
// Fill the buffer then publish one more — should drop, not block.
|
||||||
|
evt := RequestEvent{Model: "model", Status: "success"}
|
||||||
|
for i := 0; i < requestEventBufSize; i++ {
|
||||||
|
bus.Publish(evt)
|
||||||
|
}
|
||||||
|
// This publish should return immediately (dropped).
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
bus.Publish(evt)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Publish blocked when buffer was full")
|
||||||
|
}
|
||||||
|
assert.Len(t, ch, requestEventBufSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestEventBus_NilSafePublish(t *testing.T) {
|
||||||
|
var bus *RequestEventBus
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
bus.Publish(RequestEvent{Model: "test"})
|
||||||
|
})
|
||||||
|
}
|
||||||
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrRPMWaitTimeout reports that AcquireWithWait gave up because no token
// could be obtained within its maxWait budget.
var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
|
||||||
|
|
||||||
|
// RPMTokenBucketService smooths per-account request rates with one token
// bucket per account.
//
// Once an account's RPM budget is used up, a caller may wait (up to a
// caller-supplied limit) for the next token instead of being rejected with an
// immediate 429. Buckets refill continuously at rpm/60 tokens per second, which
// spreads requests evenly across the minute.
type RPMTokenBucketService struct {
	// buckets holds the per-account state: account ID (int64) -> *rpmEntry.
	buckets sync.Map
}
|
||||||
|
|
||||||
|
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
|
||||||
|
func NewRPMTokenBucketService() *RPMTokenBucketService {
|
||||||
|
return &RPMTokenBucketService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpmEntry struct {
|
||||||
|
bucket *tokenBucket
|
||||||
|
rpm int
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBucket returns (or creates) the token bucket for accountID.
|
||||||
|
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
|
||||||
|
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
|
||||||
|
if v, ok := s.buckets.Load(accountID); ok {
|
||||||
|
e := v.(*rpmEntry)
|
||||||
|
if e.rpm == rpm {
|
||||||
|
return e.bucket
|
||||||
|
}
|
||||||
|
// RPM limit changed — replace with a fresh bucket.
|
||||||
|
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||||
|
s.buckets.Store(accountID, fresh)
|
||||||
|
return fresh.bucket
|
||||||
|
}
|
||||||
|
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||||
|
actual, _ := s.buckets.LoadOrStore(accountID, entry)
|
||||||
|
return actual.(*rpmEntry).bucket
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireWithWait attempts to consume one token for the given account.
|
||||||
|
// It blocks up to maxWait for a token to become available.
|
||||||
|
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
|
||||||
|
// or ctx.Err() if the context is cancelled.
|
||||||
|
// If rpm <= 0 the call returns immediately with nil.
|
||||||
|
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||||
|
if rpm <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bucket := s.getBucket(accountID, rpm)
|
||||||
|
deadline := time.Now().Add(maxWait)
|
||||||
|
|
||||||
|
for {
|
||||||
|
ok, waitDur := bucket.tryAcquire()
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining <= 0 || waitDur > remaining {
|
||||||
|
return ErrRPMWaitTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-time.After(waitDur):
|
||||||
|
// token may be available now; retry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenBucket is a continuously refilling token bucket belonging to a single
// account. All fields below mu are guarded by mu.
type tokenBucket struct {
	mu sync.Mutex
	// tokens is the current (fractional) number of available tokens.
	tokens float64
	// maxTokens caps tokens after a refill.
	maxTokens float64
	// rateSec is the refill rate in tokens per second (rpm / 60).
	rateSec float64
	// lastFill is the time of the most recent refill computation.
	lastFill time.Time
}
|
||||||
|
|
||||||
|
func newTokenBucket(rpm int) *tokenBucket {
|
||||||
|
max := float64(rpm)
|
||||||
|
return &tokenBucket{
|
||||||
|
tokens: max,
|
||||||
|
maxTokens: max,
|
||||||
|
rateSec: float64(rpm) / 60.0,
|
||||||
|
lastFill: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
|
||||||
|
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
|
||||||
|
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(b.lastFill).Seconds()
|
||||||
|
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
|
||||||
|
b.lastFill = now
|
||||||
|
|
||||||
|
if b.tokens >= 1.0 {
|
||||||
|
b.tokens -= 1.0
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
deficit := 1.0 - b.tokens
|
||||||
|
waitSecs := deficit / b.rateSec
|
||||||
|
return false, time.Duration(waitSecs * float64(time.Second))
|
||||||
|
}
|
||||||
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
ctx := context.Background()
|
||||||
|
// Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately.
|
||||||
|
for i := 0; i < 60; i++ {
|
||||||
|
err := svc.AcquireWithWait(ctx, 1, 60, 0)
|
||||||
|
require.NoError(t, err, "call %d should succeed immediately", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
err := svc.AcquireWithWait(context.Background(), 42, 0, 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// rpm=1 → 1 token/minute. One call drains the bucket.
|
||||||
|
err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second)
|
||||||
|
require.NoError(t, err, "first call should succeed")
|
||||||
|
|
||||||
|
// Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms.
|
||||||
|
start := time.Now()
|
||||||
|
err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
assert.ErrorIs(t, err, ErrRPMWaitTimeout)
|
||||||
|
assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully.
|
||||||
|
for i := 0; i < 120; i++ {
|
||||||
|
require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next call needs to wait ~500ms for the next token. Give it 2s.
|
||||||
|
start := time.Now()
|
||||||
|
err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.NoError(t, err, "should succeed after waiting for refill")
|
||||||
|
assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited")
|
||||||
|
assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_ContextCancellation(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
|
||||||
|
// rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining.
|
||||||
|
// maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms).
|
||||||
|
// Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first.
|
||||||
|
for i := 0; i < 120; i++ {
|
||||||
|
require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Drain account 1 (rpm=1).
|
||||||
|
require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0))
|
||||||
|
|
||||||
|
// Account 2 has its own bucket and should succeed immediately.
|
||||||
|
err := svc.AcquireWithWait(ctx, 2, 1, 0)
|
||||||
|
assert.NoError(t, err, "different account should have an independent bucket")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) {
|
||||||
|
svc := NewRPMTokenBucketService()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create bucket with rpm=1 and drain it.
|
||||||
|
require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0))
|
||||||
|
// Bucket now empty with rpm=1.
|
||||||
|
|
||||||
|
// Changing RPM to 60 should reset the bucket to full (60 tokens).
|
||||||
|
err := svc.AcquireWithWait(ctx, 10, 60, 0)
|
||||||
|
assert.NoError(t, err, "new RPM should cause bucket recreation")
|
||||||
|
}
|
||||||
@ -426,6 +426,8 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideBillingCacheService,
|
ProvideBillingCacheService,
|
||||||
NewAnnouncementService,
|
NewAnnouncementService,
|
||||||
NewAdminService,
|
NewAdminService,
|
||||||
|
NewRPMTokenBucketService,
|
||||||
|
NewRequestEventBus,
|
||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
NewOpenAIGatewayService,
|
NewOpenAIGatewayService,
|
||||||
NewOAuthService,
|
NewOAuthService,
|
||||||
@ -448,6 +450,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideSettingService,
|
ProvideSettingService,
|
||||||
NewDataManagementService,
|
NewDataManagementService,
|
||||||
ProvideBackupService,
|
ProvideBackupService,
|
||||||
|
NewHealthService,
|
||||||
ProvideOpsSystemLogSink,
|
ProvideOpsSystemLogSink,
|
||||||
NewOpsService,
|
NewOpsService,
|
||||||
ProvideOpsMetricsCollector,
|
ProvideOpsMetricsCollector,
|
||||||
|
|||||||
@ -120,7 +120,7 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- sub2api-network
|
- sub2api-network
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health" ]
|
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/ready"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user