feat: merge feat/omniroute-ideas — P2C scheduler, quota scoring, tier fallback
This commit is contained in:
commit
fdd2d08a4d
@ -188,7 +188,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
@ -203,7 +204,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
opsHandler := admin.NewOpsHandler(opsService)
|
||||
requestEventBus := service.NewRequestEventBus()
|
||||
opsHandler := admin.NewOpsHandler(opsService, requestEventBus)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||
@ -244,7 +246,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, windsurfGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
totpHandler := handler.NewTotpHandler(totpService)
|
||||
@ -257,7 +259,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
|
||||
healthService := service.NewHealthService(db, redisClient)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||
|
||||
@ -691,6 +691,14 @@ type GatewayConfig struct {
|
||||
// UserMessageQueue: 用户消息串行队列配置
|
||||
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
|
||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||
|
||||
// RPMSmoothing: RPM 令牌桶平滑配置
|
||||
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
||||
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
||||
|
||||
// ContextCompression: 主动上下文压缩配置
|
||||
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
|
||||
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
|
||||
}
|
||||
|
||||
type GatewayAntigravityLSWorkerConfig struct {
|
||||
@ -745,6 +753,37 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RPMSmoothingConfig RPM 令牌桶平滑配置
|
||||
type RPMSmoothingConfig struct {
|
||||
// Enabled: 是否启用 RPM 令牌桶平滑(默认 false)
|
||||
// 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429(默认 5000)
|
||||
MaxWaitMS int `mapstructure:"max_wait_ms"`
|
||||
}
|
||||
|
||||
// MaxWait returns the configured wait duration, defaulting to 5s.
|
||||
func (c *RPMSmoothingConfig) MaxWait() time.Duration {
|
||||
if c.MaxWaitMS <= 0 {
|
||||
return 5 * time.Second
|
||||
}
|
||||
return time.Duration(c.MaxWaitMS) * time.Millisecond
|
||||
}
|
||||
|
||||
// ContextCompressionConfig 主动上下文压缩配置
|
||||
type ContextCompressionConfig struct {
|
||||
// MaxTokens: 压缩目标 token 数(chars/4 近似),超出时从最旧消息开始裁剪(默认 190000)
|
||||
MaxTokens int `mapstructure:"max_tokens"`
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the configured token budget, defaulting to 190 000.
|
||||
func (c *ContextCompressionConfig) GetMaxTokens() int {
|
||||
if c.MaxTokens <= 0 {
|
||||
return 190_000
|
||||
}
|
||||
return c.MaxTokens
|
||||
}
|
||||
|
||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
@ -828,6 +867,8 @@ type GatewayOpenAIWSConfig struct {
|
||||
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
|
||||
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
|
||||
|
||||
// EnableP2CScheduling: 启用 Power-of-Two-Choices 调度(默认 false,使用 top-K 加权随机)
|
||||
EnableP2CScheduling bool `mapstructure:"enable_p2c_scheduling"`
|
||||
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
|
||||
}
|
||||
|
||||
@ -838,6 +879,8 @@ type GatewayOpenAIWSSchedulerScoreWeights struct {
|
||||
Queue float64 `mapstructure:"queue"`
|
||||
ErrorRate float64 `mapstructure:"error_rate"`
|
||||
TTFT float64 `mapstructure:"ttft"`
|
||||
// Quota: 剩余配额比例权重(0 表示不参与打分)
|
||||
Quota float64 `mapstructure:"quota"`
|
||||
}
|
||||
|
||||
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
||||
@ -992,6 +1035,10 @@ type GatewaySchedulingConfig struct {
|
||||
// 全量重建周期配置
|
||||
// 全量重建周期(秒),0 表示禁用
|
||||
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
|
||||
|
||||
// EnableTierFallbackChain: 启用跨档降级链(订阅 → API Key → Bedrock),默认 false
|
||||
// 仅对 Anthropic 平台生效;启用后账号按类型分层,优先使用订阅账号,依次降级。
|
||||
EnableTierFallbackChain bool `mapstructure:"enable_tier_fallback_chain"`
|
||||
}
|
||||
|
||||
func (s *ServerConfig) Address() string {
|
||||
@ -2504,7 +2551,8 @@ func (c *Config) Validate() error {
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Quota < 0 {
|
||||
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
|
||||
}
|
||||
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
|
||||
|
||||
@ -16,7 +16,8 @@ import (
|
||||
)
|
||||
|
||||
type OpsHandler struct {
|
||||
opsService *service.OpsService
|
||||
opsService *service.OpsService
|
||||
requestEventBus *service.RequestEventBus
|
||||
}
|
||||
|
||||
// GetErrorLogByID returns ops error log detail.
|
||||
@ -70,8 +71,8 @@ func parseOpsViewParam(c *gin.Context) string {
|
||||
}
|
||||
}
|
||||
|
||||
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
|
||||
return &OpsHandler{opsService: opsService}
|
||||
func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler {
|
||||
return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus}
|
||||
}
|
||||
|
||||
// GetErrorLogs lists ops error logs.
|
||||
|
||||
@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService {
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||
@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
h := NewOpsHandler(newRuntimeOpsService(t), nil)
|
||||
r := newOpsRuntimeRouter(h, true)
|
||||
|
||||
payload := map[string]any{
|
||||
|
||||
@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
h := NewOpsHandler(nil, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||
sink := service.NewOpsSystemLogSink(nil)
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||
h := NewOpsHandler(svc)
|
||||
h := NewOpsHandler(svc, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
h := NewOpsHandler(nil, nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h = NewOpsHandler(svc)
|
||||
h = NewOpsHandler(svc, nil)
|
||||
r = newOpsSystemLogTestRouter(h, false)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
|
||||
198
backend/internal/handler/admin/ops_ws_requests_handler.go
Normal file
198
backend/internal/handler/admin/ops_ws_requests_handler.go
Normal file
@ -0,0 +1,198 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type requestStreamWSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data service.RequestEvent `json:"data"`
|
||||
}
|
||||
|
||||
// RequestStreamWSHandler streams real-time request events to WebSocket clients.
|
||||
// GET /api/v1/admin/ops/ws/requests
|
||||
//
|
||||
// Each connected client receives a JSON message per gateway dispatch:
|
||||
//
|
||||
// {"type":"request_event","data":{"timestamp":...,"method":"POST","path":"/v1/messages",
|
||||
// "model":"claude-3-5-sonnet-20241022","account_id":42,"status":"success","latency_ms":1230}}
|
||||
func (h *OpsHandler) RequestStreamWSHandler(c *gin.Context) {
|
||||
clientIP := requestClientIP(c.Request)
|
||||
|
||||
if h == nil || h.opsService == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
|
||||
return
|
||||
}
|
||||
if h.requestEventBus == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "request event bus not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
|
||||
return
|
||||
}
|
||||
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if wsConnCount.Add(-1) == 0 {
|
||||
scheduleQPSWSIdleStop()
|
||||
}
|
||||
}()
|
||||
|
||||
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
|
||||
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] per-ip limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
defer releaseOpsWSIPSlot(clientIP)
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
handleRequestStreamWebSocket(c.Request.Context(), conn, h.requestEventBus)
|
||||
}
|
||||
|
||||
func handleRequestStreamWebSocket(parentCtx context.Context, conn *websocket.Conn, bus *service.RequestEventBus) {
|
||||
if conn == nil || bus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
defer cancel()
|
||||
|
||||
subID, eventCh := bus.Subscribe()
|
||||
defer bus.Unsubscribe(subID)
|
||||
|
||||
var closeOnce sync.Once
|
||||
closeConn := func() {
|
||||
closeOnce.Do(func() { _ = conn.Close() })
|
||||
}
|
||||
|
||||
closeFrameCh := make(chan []byte, 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
conn.SetReadLimit(qpsWSMaxReadBytes)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
|
||||
})
|
||||
conn.SetCloseHandler(func(code int, text string) error {
|
||||
select {
|
||||
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
|
||||
default:
|
||||
}
|
||||
cancel()
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
pingTicker := time.NewTicker(qpsWSPingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
writeWithTimeout := func(messageType int, data []byte) error {
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
sendClose := func(closeFrame []byte) {
|
||||
if closeFrame == nil {
|
||||
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
|
||||
}
|
||||
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-eventCh:
|
||||
if !ok {
|
||||
// channel closed by Unsubscribe
|
||||
sendClose(nil)
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
msg, err := json.Marshal(requestStreamWSMessage{Type: "request_event", Data: evt})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] write failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
|
||||
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] ping failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
case closeFrame := <-closeFrameCh:
|
||||
sendClose(closeFrame)
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
var closeFrame []byte
|
||||
select {
|
||||
case closeFrame = <-closeFrameCh:
|
||||
default:
|
||||
}
|
||||
sendClose(closeFrame)
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -48,6 +48,7 @@ type GatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
userMsgQueueHelper *UserMsgQueueHelper
|
||||
requestEventBus *service.RequestEventBus
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
cfg *config.Config
|
||||
@ -70,6 +71,7 @@ func NewGatewayHandler(
|
||||
userMsgQueueService *service.UserMessageQueueService,
|
||||
cfg *config.Config,
|
||||
settingService *service.SettingService,
|
||||
requestEventBus *service.RequestEventBus,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 10
|
||||
@ -103,6 +105,7 @@ func NewGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
userMsgQueueHelper: umqHelper,
|
||||
requestEventBus: requestEventBus,
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
cfg: cfg,
|
||||
@ -113,6 +116,7 @@ func NewGatewayHandler(
|
||||
// Messages handles Claude API compatible messages endpoint
|
||||
// POST /v1/messages
|
||||
func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqStartTime := time.Now()
|
||||
// 从context获取apiKey和user(ApiKeyAuth中间件已设置)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
@ -164,6 +168,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 解析渠道级模型映射
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
// 实时请求查看器:记录每次请求的结果(账号、模型、状态、延迟)
|
||||
var (
|
||||
reqEventAccountID int64
|
||||
reqEventStatus = "error"
|
||||
)
|
||||
defer func() {
|
||||
if h.requestEventBus != nil {
|
||||
h.requestEventBus.Publish(service.RequestEvent{
|
||||
Timestamp: reqStartTime,
|
||||
Method: c.Request.Method,
|
||||
Path: c.FullPath(),
|
||||
Model: reqModel,
|
||||
AccountID: reqEventAccountID,
|
||||
Status: reqEventStatus,
|
||||
LatencyMS: time.Since(reqStartTime).Milliseconds(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
@ -406,6 +429,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmCancel()
|
||||
if rpmErr != nil {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
reqEventAccountID = account.ID
|
||||
reqEventStatus = "rate_limited"
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
@ -473,6 +513,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 实时请求查看器:标记 Gemini 路径成功
|
||||
reqEventAccountID = account.ID
|
||||
reqEventStatus = "success"
|
||||
|
||||
// RPM 计数递增(Forward 成功后)
|
||||
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
||||
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
||||
@ -650,6 +694,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmCancel()
|
||||
if rpmErr != nil {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
reqEventAccountID = account.ID
|
||||
reqEventStatus = "rate_limited"
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
@ -850,6 +911,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 实时请求查看器:标记 Anthropic 路径成功
|
||||
reqEventAccountID = account.ID
|
||||
reqEventStatus = "success"
|
||||
|
||||
// RPM 计数递增(Forward 成功后)
|
||||
// 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
|
||||
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
|
||||
|
||||
@ -29,6 +29,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqljson"
|
||||
@ -49,6 +50,13 @@ type accountRepository struct {
|
||||
// Used to proactively sync account snapshot to cache when status changes,
|
||||
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||
schedulerCache service.SchedulerCache
|
||||
|
||||
// tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。
|
||||
// 上游 401/限流爆发时,N 个 in-flight 请求会同时调用此方法;底层 SQL
|
||||
// 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row,但 N 次
|
||||
// SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight
|
||||
// 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。
|
||||
tempUnschedSF singleflight.Group
|
||||
}
|
||||
|
||||
var schedulerNeutralExtraKeyPrefixes = []string{
|
||||
@ -1124,6 +1132,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
// 进程内合并并发调用:key 包含 until 和 reason,确保不同窗口/原因独立去重。
|
||||
// until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致;
|
||||
// 哪怕略有偏差,SQL 的 (existing < new) 条件保证语义安全。
|
||||
sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason
|
||||
_, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) {
|
||||
return nil, r.setTempUnschedulableOnce(ctx, id, until, reason)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = $1,
|
||||
|
||||
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
119
backend/internal/repository/account_repo_singleflight_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// blockingExecutor 是一个最小化的 sqlExecutor 实现,用于精确控制并发时序。
|
||||
// ExecContext 会等待 release 信号才返回,便于让多个 goroutine 集中堆积在
|
||||
// singleflight 的同一窗口内。
|
||||
type blockingExecutor struct {
|
||||
mu sync.Mutex
|
||||
execCalls int32
|
||||
queryCalls int32
|
||||
release chan struct{}
|
||||
concurrent int32
|
||||
maxObserved int32
|
||||
}
|
||||
|
||||
func newBlockingExecutor() *blockingExecutor {
|
||||
return &blockingExecutor{release: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (e *blockingExecutor) Release() { close(e.release) }
|
||||
|
||||
func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||
atomic.AddInt32(&e.execCalls, 1)
|
||||
c := atomic.AddInt32(&e.concurrent, 1)
|
||||
for {
|
||||
old := atomic.LoadInt32(&e.maxObserved)
|
||||
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer atomic.AddInt32(&e.concurrent, -1)
|
||||
<-e.release
|
||||
return driverResult{}, nil
|
||||
}
|
||||
|
||||
func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) {
|
||||
atomic.AddInt32(&e.queryCalls, 1)
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
// driverResult 是一个零值 sql.Result,用于测试。
|
||||
type driverResult struct{}
|
||||
|
||||
func (driverResult) LastInsertId() (int64, error) { return 0, nil }
|
||||
func (driverResult) RowsAffected() (int64, error) { return 1, nil }
|
||||
|
||||
func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||
// 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际
|
||||
// SQL 路径(UPDATE + outbox INSERT = 2 次 ExecContext)。
|
||||
exec := newBlockingExecutor()
|
||||
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||
|
||||
const callers = 30
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
const reason = "OAuth 401: invalid_grant"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
for i := 0; i < callers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = repo.SetTempUnschedulable(context.Background(), 42, until, reason)
|
||||
}()
|
||||
}
|
||||
|
||||
// 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent),
|
||||
"singleflight should serialize the SQL call to exactly one in-flight execution")
|
||||
|
||||
exec.Release()
|
||||
wg.Wait()
|
||||
|
||||
// 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec;其余 29 个 caller 共享结果。
|
||||
require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2),
|
||||
"expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved),
|
||||
"no two SQL execs should run concurrently for the same singleflight key")
|
||||
}
|
||||
|
||||
func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) {
|
||||
// 不同 account 应分属不同 sf key,能并行写库。
|
||||
exec := newBlockingExecutor()
|
||||
repo := newAccountRepositoryWithSQL(nil, exec, nil)
|
||||
|
||||
until := time.Now().Add(10 * time.Minute)
|
||||
var wg sync.WaitGroup
|
||||
for i := int64(1); i <= 3; i++ {
|
||||
i := i
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason")
|
||||
}()
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved),
|
||||
"different accounts should be able to write in parallel")
|
||||
|
||||
exec.Release()
|
||||
wg.Wait()
|
||||
}
|
||||
@ -38,6 +38,7 @@ func ProvideRouter(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
healthService *service.HealthService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
@ -95,7 +96,7 @@ func ProvideRouter(
|
||||
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
|
||||
})
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@ -30,6 +30,7 @@ func SetupRouter(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
healthService *service.HealthService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
@ -81,7 +82,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@ -97,11 +98,12 @@ func registerRoutes(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
healthService *service.HealthService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
routes.RegisterCommonRoutes(r, healthService)
|
||||
|
||||
// API v1
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
@ -151,10 +151,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
|
||||
}
|
||||
|
||||
// WebSocket realtime (QPS/TPS)
|
||||
// WebSocket realtime (QPS/TPS and request stream)
|
||||
ws := ops.Group("/ws")
|
||||
{
|
||||
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
||||
ws.GET("/requests", h.Admin.Ops.RequestStreamWSHandler)
|
||||
}
|
||||
|
||||
// Error logs (legacy)
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@ -16,11 +17,37 @@ const (
|
||||
claudeCodeGrowthBookDateUpdated = "1970-01-01T00:00:00Z"
|
||||
)
|
||||
|
||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
||||
func RegisterCommonRoutes(r *gin.Engine) {
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
// readinessHandlerTimeout 限定 readiness 端点对外的最大返回耗时。
|
||||
// HealthService 内部对每个组件再有独立超时,所以这里给宽一点即可。
|
||||
const readinessHandlerTimeout = 3 * time.Second
|
||||
|
||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)。
|
||||
//
|
||||
// 健康端点的语义分层:
|
||||
// - /healthz : liveness 探针。零依赖、永远 200。容器/进程探活专用。
|
||||
// - /ready : readiness 探针。检查 DB+Redis;任一失败返回 503。
|
||||
// - /health : 历史端点,等价于 /healthz,保留向后兼容。
|
||||
//
|
||||
// dashboard 用的"业务健康分"由 ops_health_score 单独提供,与本路由无关。
|
||||
func RegisterCommonRoutes(r *gin.Engine, healthService *service.HealthService) {
|
||||
// Liveness:仅证明进程在响应。
|
||||
livenessHandler := func(c *gin.Context) {
|
||||
_ = healthService.Liveness()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
r.GET("/healthz", livenessHandler)
|
||||
r.GET("/health", livenessHandler) // 向后兼容旧的 docker-compose healthcheck
|
||||
|
||||
// Readiness:检查关键依赖。失败时返回 503 但仍带详情,便于排障。
|
||||
r.GET("/ready", func(c *gin.Context) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), readinessHandlerTimeout)
|
||||
defer cancel()
|
||||
report := healthService.Readiness(ctx)
|
||||
status := http.StatusOK
|
||||
if !report.OK {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
c.JSON(status, report)
|
||||
})
|
||||
|
||||
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
|
||||
|
||||
@ -7,16 +7,56 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newCommonRoutesTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterCommonRoutes(r)
|
||||
RegisterCommonRoutes(r, service.NewHealthService(nil, nil))
|
||||
return r
|
||||
}
|
||||
|
||||
func newTestRouter(t *testing.T, hs *service.HealthService) *gin.Engine {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
RegisterCommonRoutes(r, hs)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCommonRoutes_LivenessEndpoints(t *testing.T) {
|
||||
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||
for _, path := range []string{"/healthz", "/health"} {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code, "liveness path %s should be 200", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommonRoutes_ReadyEndpoint_NoDepsReturnsOK(t *testing.T) {
|
||||
// 没有 DB/Redis 依赖时 readiness 视为 ok(早期启动场景)。
|
||||
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||
req := httptest.NewRequest(http.MethodGet, "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Contains(t, w.Body.String(), "\"ok\":true")
|
||||
}
|
||||
|
||||
func TestCommonRoutes_SetupStatusUnchanged(t *testing.T) {
|
||||
// 验证我们没有破坏既有的 /setup/status 行为(前端依赖)。
|
||||
r := newTestRouter(t, service.NewHealthService(nil, nil))
|
||||
req := httptest.NewRequest(http.MethodGet, "/setup/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Contains(t, w.Body.String(), "needs_setup")
|
||||
}
|
||||
|
||||
func TestCommonRoutes_ClaudeCodeBootstrap(t *testing.T) {
|
||||
r := newCommonRoutesTestRouter()
|
||||
|
||||
|
||||
@ -988,6 +988,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsContextCompressionEnabled returns true if the account has opted into proactive
|
||||
// context compression. When enabled, the gateway will trim oldest messages before
|
||||
// dispatch to keep the estimated token count within the configured budget.
|
||||
func (a *Account) IsContextCompressionEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
enabled, _ := a.Credentials["enable_context_compression"].(bool)
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
@ -1572,6 +1583,24 @@ func (a *Account) GetQuotaUsed() float64 {
|
||||
return a.getExtraFloat64("quota_used")
|
||||
}
|
||||
|
||||
// GetQuotaRemainingFraction returns the fraction of total quota remaining in [0,1].
|
||||
// Returns 1.0 when no quota limit is set (limit == 0 means unlimited).
|
||||
func (a *Account) GetQuotaRemainingFraction() float64 {
|
||||
limit := a.GetQuotaLimit()
|
||||
if limit <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
used := a.GetQuotaUsed()
|
||||
remaining := (limit - used) / limit
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
if remaining > 1 {
|
||||
return 1
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用
|
||||
func (a *Account) GetQuotaDailyLimit() float64 {
|
||||
return a.getExtraFloat64("quota_daily_limit")
|
||||
|
||||
@ -40,6 +40,7 @@ func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{"标准版本号", "claude-cli/1.0.0", true},
|
||||
{"官方 transport UA", "claude-code/2.1.88", true},
|
||||
{"多位版本号", "claude-cli/12.34.56", true},
|
||||
{"大写开头", "Claude-CLI/1.0.0", true},
|
||||
{"非 claude-cli", "curl/7.64.1", false},
|
||||
@ -90,6 +91,19 @@ func TestValidate_MessagesPath_FullValid(t *testing.T) {
|
||||
require.True(t, result, "完整有效请求应通过")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_FullValid_ClaudeCodeUA(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-code/2.1.88")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
result := v.Validate(req, validClaudeCodeBody())
|
||||
require.True(t, result, "官方 transport/helper UA 也应通过")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
body := validClaudeCodeBody()
|
||||
|
||||
@ -15,11 +15,13 @@ import (
|
||||
type ClaudeCodeValidator struct{}
|
||||
|
||||
var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
// User-Agent 匹配: 官方 Claude Code 目前存在两类产品前缀:
|
||||
// 1. 主 Anthropic API 客户端: claude-cli/x.y.z (...)
|
||||
// 2. transport / helper 请求: claude-code/x.y.z
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/\d+\.\d+\.\d+`)
|
||||
|
||||
// 带捕获组的版本提取正则
|
||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/(\d+\.\d+\.\d+)`)
|
||||
|
||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||
systemPromptThreshold = 0.5
|
||||
@ -55,7 +57,7 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
// Validate 验证请求是否来自 Claude Code CLI
|
||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是官方 claude-cli/ 或 claude-code/ 前缀
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
||||
// Step 4: 对于 messages 路径,进行严格验证:
|
||||
|
||||
@ -64,6 +64,7 @@ func TestExtractVersion(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
|
||||
{"claude-code/2.1.88", "2.1.88"},
|
||||
{"claude-cli/1.0.0", "1.0.0"},
|
||||
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
|
||||
{"curl/8.0.0", ""}, // 非 Claude CLI
|
||||
|
||||
151
backend/internal/service/context_compressor.go
Normal file
151
backend/internal/service/context_compressor.go
Normal file
@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
|
||||
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
|
||||
const defaultContextCompressionMaxTokens = 190_000
|
||||
|
||||
// approxTokens estimates the token count for a string using the chars/4 heuristic.
|
||||
func approxTokens(s string) int {
|
||||
return int(math.Ceil(float64(len(s)) / 4.0))
|
||||
}
|
||||
|
||||
// compressMessagesInBody trims the oldest messages from the request body so that the
|
||||
// estimated token count of the messages array fits within maxTokens.
|
||||
// Returns the original body unchanged if no compression is needed or if parsing fails.
|
||||
func compressMessagesInBody(body []byte, maxTokens int) []byte {
|
||||
msgsResult := gjson.GetBytes(body, "messages")
|
||||
if !msgsResult.Exists() || !msgsResult.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
// Unmarshal to a typed slice for processing.
|
||||
var messages []map[string]any
|
||||
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
compressed, changed := compressMessages(messages, maxTokens)
|
||||
if !changed {
|
||||
return body
|
||||
}
|
||||
|
||||
newMsgs, err := json.Marshal(compressed)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// compressMessages removes the oldest messages from the front of msgs until the
|
||||
// estimated total token count is at or below maxTokens.
|
||||
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
|
||||
// to avoid orphaned tool_result blocks.
|
||||
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
|
||||
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
|
||||
if len(msgs) == 0 {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
// Estimate total tokens.
|
||||
totalTokens := 0
|
||||
for _, m := range msgs {
|
||||
totalTokens += msgTokens(m)
|
||||
}
|
||||
if totalTokens <= maxTokens {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
|
||||
type unit struct {
|
||||
startIdx int
|
||||
endIdx int // exclusive
|
||||
tokens int
|
||||
}
|
||||
units := make([]unit, 0, len(msgs))
|
||||
i := 0
|
||||
for i < len(msgs) {
|
||||
toks := msgTokens(msgs[i])
|
||||
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
|
||||
toks += msgTokens(msgs[i+1])
|
||||
units = append(units, unit{i, i + 2, toks})
|
||||
i += 2
|
||||
} else {
|
||||
units = append(units, unit{i, i + 1, toks})
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// Remove units from the front until we are within budget.
|
||||
// Always keep at least the last unit so we never send an empty messages array.
|
||||
removeCount := 0
|
||||
for removeCount < len(units)-1 && totalTokens > maxTokens {
|
||||
totalTokens -= units[removeCount].tokens
|
||||
removeCount++
|
||||
}
|
||||
if removeCount == 0 {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
cutIdx := units[removeCount].startIdx
|
||||
return msgs[cutIdx:], true
|
||||
}
|
||||
|
||||
// msgTokens estimates token count for a single message using the chars/4 heuristic.
|
||||
func msgTokens(msg map[string]any) int {
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return approxTokens(string(b))
|
||||
}
|
||||
|
||||
// isAssistantWithToolUse returns true if msg is an assistant message whose content
|
||||
// contains at least one block with "type": "tool_use".
|
||||
func isAssistantWithToolUse(msg map[string]any) bool {
|
||||
role, _ := msg["role"].(string)
|
||||
if role != "assistant" {
|
||||
return false
|
||||
}
|
||||
return contentContainsType(msg["content"], "tool_use")
|
||||
}
|
||||
|
||||
// isUserWithToolResult returns true if msg is a user message whose content
|
||||
// contains at least one block with "type": "tool_result".
|
||||
func isUserWithToolResult(msg map[string]any) bool {
|
||||
role, _ := msg["role"].(string)
|
||||
if role != "user" {
|
||||
return false
|
||||
}
|
||||
return contentContainsType(msg["content"], "tool_result")
|
||||
}
|
||||
|
||||
// contentContainsType returns true if content (a []any of blocks) contains a block
|
||||
// whose "type" field equals blockType.
|
||||
func contentContainsType(content any, blockType string) bool {
|
||||
blocks, ok := content.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, b := range blocks {
|
||||
block, ok := b.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if t, _ := block["type"].(string); t == blockType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
195
backend/internal/service/context_compressor_test.go
Normal file
195
backend/internal/service/context_compressor_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// helpers
|
||||
|
||||
func makeMsg(role, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": text,
|
||||
}
|
||||
}
|
||||
|
||||
func makeToolUseMsg(id string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": id,
|
||||
"name": "search",
|
||||
"input": map[string]any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeToolResultMsg(toolUseID string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": "result text",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toAnySlice(msgs []map[string]any) []any {
|
||||
out := make([]any, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = m
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
|
||||
require.NoError(t, err)
|
||||
return b
|
||||
}
|
||||
|
||||
// tests
|
||||
|
||||
func TestApproxTokens(t *testing.T) {
|
||||
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
|
||||
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
|
||||
assert.Equal(t, 0, approxTokens(""))
|
||||
}
|
||||
|
||||
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "hi"),
|
||||
makeMsg("assistant", "hello"),
|
||||
}
|
||||
result, changed := compressMessages(msgs, 100_000)
|
||||
assert.False(t, changed)
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
|
||||
// 10 messages, each large enough to be over a tight budget when combined.
|
||||
msgs := make([]map[string]any, 10)
|
||||
for i := range msgs {
|
||||
role := "user"
|
||||
if i%2 == 1 {
|
||||
role = "assistant"
|
||||
}
|
||||
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
|
||||
}
|
||||
|
||||
// Force compression by using a very small token budget.
|
||||
result, changed := compressMessages(msgs, 1)
|
||||
assert.True(t, changed)
|
||||
// Must keep at least one message (the last).
|
||||
assert.GreaterOrEqual(t, len(result), 1)
|
||||
// The remaining messages should be from the tail (newest).
|
||||
lastOrig := msgs[len(msgs)-1]["content"]
|
||||
lastResult := result[len(result)-1]["content"]
|
||||
assert.Equal(t, lastOrig, lastResult)
|
||||
}
|
||||
|
||||
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
|
||||
// Messages: user → assistant+tool_use → user+tool_result → assistant
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "start"),
|
||||
makeToolUseMsg("tool-1"),
|
||||
makeToolResultMsg("tool-1"),
|
||||
makeMsg("assistant", "done"),
|
||||
}
|
||||
|
||||
// Budget that forces removal of the first non-paired message but keeps the tool pair.
|
||||
// Estimate total tokens and set budget to force removing only "start" but not the pair.
|
||||
total := 0
|
||||
for _, m := range msgs {
|
||||
total += msgTokens(m)
|
||||
}
|
||||
// Budget: remove "start" but keep tool pair + "done".
|
||||
startTokens := msgTokens(msgs[0])
|
||||
budget := total - startTokens
|
||||
|
||||
result, changed := compressMessages(msgs, budget)
|
||||
assert.True(t, changed)
|
||||
|
||||
// tool_use and tool_result should both be present or both absent.
|
||||
hasToolUse := false
|
||||
hasToolResult := false
|
||||
for _, m := range result {
|
||||
if isAssistantWithToolUse(m) {
|
||||
hasToolUse = true
|
||||
}
|
||||
if isUserWithToolResult(m) {
|
||||
hasToolResult = true
|
||||
}
|
||||
}
|
||||
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
|
||||
}
|
||||
|
||||
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
|
||||
// Budget forces removal of the tool pair.
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "start"),
|
||||
makeToolUseMsg("tool-1"),
|
||||
makeToolResultMsg("tool-1"),
|
||||
makeMsg("assistant", "final answer after tool use"),
|
||||
}
|
||||
|
||||
// Budget: only keep the last "assistant" message.
|
||||
lastTokens := msgTokens(msgs[len(msgs)-1])
|
||||
|
||||
result, changed := compressMessages(msgs, lastTokens)
|
||||
assert.True(t, changed)
|
||||
|
||||
// Neither tool_use nor tool_result should remain.
|
||||
for _, m := range result {
|
||||
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
|
||||
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
|
||||
result := compressMessagesInBody(body, 1)
|
||||
assert.Equal(t, body, result, "body without messages should be unchanged")
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
|
||||
msgs := []map[string]any{makeMsg("user", "hi")}
|
||||
body := bodyWithMessages(t, msgs)
|
||||
result := compressMessagesInBody(body, 100_000)
|
||||
assert.Equal(t, body, result, "body under budget should be unchanged")
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
|
||||
msgs := make([]map[string]any, 20)
|
||||
for i := range msgs {
|
||||
role := "user"
|
||||
if i%2 == 1 {
|
||||
role = "assistant"
|
||||
}
|
||||
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
|
||||
}
|
||||
body := bodyWithMessages(t, msgs)
|
||||
|
||||
// Force significant compression.
|
||||
result := compressMessagesInBody(body, 50)
|
||||
assert.Less(t, len(result), len(body), "compressed body should be smaller")
|
||||
|
||||
// Resulting body should still be valid JSON with a messages array.
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||
resultMsgs, ok := parsed["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
|
||||
}
|
||||
@ -689,6 +689,41 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
|
||||
require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
sessionID := "12345678-1234-1234-1234-123456789abc"
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"metadata": map[string]any{
|
||||
"user_id": FormatMetadataUserID(
|
||||
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
|
||||
"",
|
||||
sessionID,
|
||||
claude.DefaultCLIProductVersion,
|
||||
),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||
},
|
||||
}
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
req, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-3-7-sonnet-20250219", false, false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sessionID, getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"))
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -335,8 +335,9 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`)
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@ -569,7 +570,8 @@ type GatewayService struct {
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
@ -614,6 +616,7 @@ func NewGatewayService(
|
||||
channelService *ChannelService,
|
||||
resolver *ModelPricingResolver,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
rpmTokenBucketSvc *RPMTokenBucketService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@ -640,6 +643,7 @@ func NewGatewayService(
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
rpmTokenBucket: rpmTokenBucketSvc,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
@ -1374,6 +1378,9 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||
if platform == PlatformAnthropic && s.enableTierFallbackChain() {
|
||||
return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -2558,6 +2565,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
|
||||
return err
|
||||
}
|
||||
|
||||
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
|
||||
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
|
||||
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||
if s.rpmTokenBucket == nil {
|
||||
return nil
|
||||
}
|
||||
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
|
||||
}
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||
@ -3765,6 +3781,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
return ParseMetadataUserID(metadataUserID) != nil
|
||||
}
|
||||
|
||||
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return true
|
||||
}
|
||||
if parsed == nil || c == nil {
|
||||
return false
|
||||
}
|
||||
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||
}
|
||||
|
||||
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
|
||||
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
|
||||
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
|
||||
@ -4323,6 +4349,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
|
||||
if account.IsContextCompressionEnabled() {
|
||||
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
|
||||
body = compressMessagesInBody(body, maxTok)
|
||||
}
|
||||
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
@ -5923,15 +5955,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// X-Claude-Code-Session-Id 头处理:
|
||||
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id
|
||||
// 2. 客户端未提供(mimic 模式)→ 生成确定性 per-account session UUID
|
||||
// 真实 CLI 每个请求都携带此 header(per-process UUID)。
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
// Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
|
||||
// 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -8969,12 +8997,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
// Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。
|
||||
// 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
133
backend/internal/service/gateway_tier_fallback.go
Normal file
133
backend/internal/service/gateway_tier_fallback.go
Normal file
@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// accountTierLevel maps an account type to a scheduling tier:
|
||||
//
|
||||
// 0 = subscription (OAuth / SetupToken) — tried first
|
||||
// 1 = API Key — first fallback
|
||||
// 2 = Bedrock — last resort
|
||||
//
|
||||
// Accounts with an unknown type fall into tier 0 so they participate in the
|
||||
// primary selection and do not vanish silently.
|
||||
func accountTierLevel(account *Account) int {
|
||||
if account == nil {
|
||||
return 0
|
||||
}
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
return 1
|
||||
case AccountTypeBedrock:
|
||||
return 2
|
||||
default: // OAuth, SetupToken, or unknown
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// enableTierFallbackChain reports whether the cross-tier fallback chain is
|
||||
// enabled in config (default false).
|
||||
func (s *GatewayService) enableTierFallbackChain() bool {
|
||||
return s != nil && s.cfg != nil && s.cfg.Gateway.Scheduling.EnableTierFallbackChain
|
||||
}
|
||||
|
||||
// selectAccountWithTierFallback tries Anthropic accounts in tier order:
|
||||
// tier 0 (OAuth/SetupToken subscription) → tier 1 (API Key) → tier 2 (Bedrock).
|
||||
//
|
||||
// Sticky sessions are honored within the chain: if the session-bound account is
|
||||
// in a tier that still has capacity it is returned immediately; otherwise the
|
||||
// session binding is cleared and the chain proceeds from tier 0.
|
||||
func (s *GatewayService) selectAccountWithTierFallback(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
) (*Account, error) {
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAnthropic, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// Build per-tier candidate lists (pointers into `accounts`).
|
||||
const numTiers = 3
|
||||
tierCandidates := [numTiers][]*Account{}
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.Platform != PlatformAnthropic {
|
||||
continue
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForQuota(acc) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountSchedulableForRPM(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
tier := accountTierLevel(acc)
|
||||
if tier < numTiers {
|
||||
tierCandidates[tier] = append(tierCandidates[tier], acc)
|
||||
}
|
||||
}
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
selectionMode := cfg.FallbackSelectionMode
|
||||
|
||||
// Check sticky session: if the bound account is a valid candidate, use it.
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, cacheErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if cacheErr == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
for tier := 0; tier < numTiers; tier++ {
|
||||
for _, acc := range tierCandidates[tier] {
|
||||
if acc.ID != accountID {
|
||||
continue
|
||||
}
|
||||
if shouldClearStickySession(acc, requestedModel) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
break
|
||||
}
|
||||
if s.isAccountSchedulableForWindowCost(ctx, acc, true) &&
|
||||
s.isAccountSchedulableForRPM(ctx, acc, true) {
|
||||
return acc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try each tier in order.
|
||||
for tier := 0; tier < numTiers; tier++ {
|
||||
candidates := tierCandidates[tier]
|
||||
if len(candidates) == 0 {
|
||||
continue
|
||||
}
|
||||
s.sortCandidatesForFallback(candidates, false, selectionMode)
|
||||
result, acquired, _ := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, false)
|
||||
if acquired && result != nil {
|
||||
return result.Account, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("no available accounts in any tier")
|
||||
}
|
||||
138
backend/internal/service/gateway_tier_fallback_test.go
Normal file
138
backend/internal/service/gateway_tier_fallback_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountTierLevel(t *testing.T) {
|
||||
require.Equal(t, 0, accountTierLevel(nil))
|
||||
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeOAuth}))
|
||||
require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeSetupToken}))
|
||||
require.Equal(t, 0, accountTierLevel(&Account{Type: "unknown"}))
|
||||
require.Equal(t, 1, accountTierLevel(&Account{Type: AccountTypeAPIKey}))
|
||||
require.Equal(t, 2, accountTierLevel(&Account{Type: AccountTypeBedrock}))
|
||||
}
|
||||
|
||||
func TestGatewayService_EnableTierFallbackChain(t *testing.T) {
|
||||
require.False(t, (*GatewayService)(nil).enableTierFallbackChain())
|
||||
require.False(t, (&GatewayService{}).enableTierFallbackChain())
|
||||
|
||||
cfgOff := &config.Config{}
|
||||
cfgOff.Gateway.Scheduling.EnableTierFallbackChain = false
|
||||
require.False(t, (&GatewayService{cfg: cfgOff}).enableTierFallbackChain())
|
||||
|
||||
cfgOn := &config.Config{}
|
||||
cfgOn.Gateway.Scheduling.EnableTierFallbackChain = true
|
||||
require.True(t, (&GatewayService{cfg: cfgOn}).enableTierFallbackChain())
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription verifies
|
||||
// that when both OAuth (subscription) and APIKey accounts are available, the
|
||||
// tier-0 OAuth account is always selected first even if APIKey has higher priority.
|
||||
func TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
oauthAcc := Account{ID: 91001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Priority: 5}
|
||||
apiKeyAcc := Account{ID: 91002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Priority: 0}
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||
accountsByID: map[int64]*Account{91001: &oauthAcc, 91002: &apiKeyAcc},
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||
|
||||
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(91001), acc.ID, "OAuth (tier-0) account should be preferred over APIKey (tier-1)")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey verifies
|
||||
// that when the subscription tier has no schedulable accounts, the fallback
|
||||
// selects an API Key account.
|
||||
func TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
oauthAcc := Account{ID: 92001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||
apiKeyAcc := Account{ID: 92002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||
accountsByID: map[int64]*Account{92001: &oauthAcc, 92002: &apiKeyAcc},
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||
|
||||
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(92002), acc.ID, "Should fall back to APIKey when OAuth is rate-limited")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts ensures
|
||||
// excluded IDs are respected across all tiers.
|
||||
func TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
oauthAcc := Account{ID: 93001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}
|
||||
apiKeyAcc := Account{ID: 93002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true}
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{oauthAcc, apiKeyAcc},
|
||||
accountsByID: map[int64]*Account{93001: &oauthAcc, 93002: &apiKeyAcc},
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||
|
||||
excluded := map[int64]struct{}{93001: {}}
|
||||
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", excluded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(93002), acc.ID, "Excluded OAuth account should cause APIKey fallback")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithTierFallback_NoAccounts verifies that
|
||||
// an error is returned when all tiers are empty.
|
||||
func TestGatewayService_SelectAccountWithTierFallback_NoAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{accounts: nil, accountsByID: map[int64]*Account{}}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||
|
||||
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort verifies
|
||||
// that Bedrock accounts are only used when subscription and API Key tiers are exhausted.
|
||||
func TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(30 * time.Minute)
|
||||
oauthAcc := Account{ID: 94001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||
apiKeyAcc := Account{ID: 94002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil}
|
||||
bedrockAcc := Account{ID: 94003, Platform: PlatformAnthropic, Type: AccountTypeBedrock, Status: StatusActive, Schedulable: true}
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{oauthAcc, apiKeyAcc, bedrockAcc},
|
||||
accountsByID: map[int64]*Account{94001: &oauthAcc, 94002: &apiKeyAcc, 94003: &bedrockAcc},
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()}
|
||||
|
||||
acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(94003), acc.ID, "Bedrock should be selected as last resort")
|
||||
}
|
||||
119
backend/internal/service/health_service.go
Normal file
119
backend/internal/service/health_service.go
Normal file
@ -0,0 +1,119 @@
|
||||
// Package service - HealthService 提供 liveness 与 readiness 探针。
|
||||
//
|
||||
// 设计动机:原有 /health 端点既被 docker-compose healthcheck 使用,又被
|
||||
// dashboard 的 ops_health_score 复用——后者会触发 DB/Redis 等重操作,
|
||||
// 导致探活流量污染监控指标。本服务把两类语义拆开:
|
||||
// - Liveness : 仅证明进程存活(无外部依赖检查)。
|
||||
// - Readiness : 检查 DB + Redis 连通,作为是否可接收流量的判断。
|
||||
//
|
||||
// dashboard 维度的"业务健康分"仍由 ops_health_score 计算,与本服务无关。
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 探针默认超时。Readiness 探针需要快速失败,避免堆积。
|
||||
const (
|
||||
defaultReadinessTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// ReadinessReport 描述各依赖项的状态,便于上层暴露细节给排障。
|
||||
type ReadinessReport struct {
|
||||
OK bool `json:"ok"`
|
||||
Details map[string]ComponentStatus `json:"details"`
|
||||
Elapsed time.Duration `json:"elapsed_ms"`
|
||||
}
|
||||
|
||||
// ComponentStatus 单个依赖项的状态。Error 字段在 OK=true 时为空。
|
||||
type ComponentStatus struct {
|
||||
OK bool `json:"ok"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Elapsed string `json:"elapsed,omitempty"`
|
||||
}
|
||||
|
||||
// HealthService 提供 liveness/readiness 探针。
|
||||
// 字段都允许为 nil:缺失的依赖在 readiness 中自动跳过,便于测试和分阶段启用。
|
||||
type HealthService struct {
|
||||
db *sql.DB
|
||||
rdb *redis.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewHealthService 构造函数。timeout<=0 时使用默认值。
|
||||
func NewHealthService(db *sql.DB, rdb *redis.Client) *HealthService {
|
||||
return &HealthService{
|
||||
db: db,
|
||||
rdb: rdb,
|
||||
timeout: defaultReadinessTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Liveness 仅返回 nil。任何调用方能拿到这个返回值就说明进程在响应请求。
|
||||
// 保持无副作用、零依赖,便于 K8s livenessProbe 高频调用。
|
||||
func (s *HealthService) Liveness() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Readiness 检查所有外部依赖。任一失败则整体 OK=false。
|
||||
// 单个依赖的 ctx 超时由 timeout 控制,独立计时不互相阻塞。
|
||||
func (s *HealthService) Readiness(ctx context.Context) ReadinessReport {
|
||||
start := time.Now()
|
||||
report := ReadinessReport{
|
||||
OK: true,
|
||||
Details: make(map[string]ComponentStatus, 2),
|
||||
}
|
||||
|
||||
if s.db != nil {
|
||||
report.Details["database"] = s.checkDB(ctx)
|
||||
if !report.Details["database"].OK {
|
||||
report.OK = false
|
||||
}
|
||||
}
|
||||
if s.rdb != nil {
|
||||
report.Details["redis"] = s.checkRedis(ctx)
|
||||
if !report.Details["redis"].OK {
|
||||
report.OK = false
|
||||
}
|
||||
}
|
||||
|
||||
report.Elapsed = time.Since(start)
|
||||
return report
|
||||
}
|
||||
|
||||
func (s *HealthService) checkDB(parent context.Context) ComponentStatus {
|
||||
ctx, cancel := context.WithTimeout(parent, s.timeout)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
err := s.db.PingContext(ctx)
|
||||
status := ComponentStatus{Elapsed: time.Since(start).String()}
|
||||
if err != nil {
|
||||
status.Error = err.Error()
|
||||
return status
|
||||
}
|
||||
status.OK = true
|
||||
return status
|
||||
}
|
||||
|
||||
func (s *HealthService) checkRedis(parent context.Context) ComponentStatus {
|
||||
ctx, cancel := context.WithTimeout(parent, s.timeout)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
pong, err := s.rdb.Ping(ctx).Result()
|
||||
status := ComponentStatus{Elapsed: time.Since(start).String()}
|
||||
if err != nil {
|
||||
status.Error = err.Error()
|
||||
return status
|
||||
}
|
||||
if pong != "PONG" {
|
||||
status.Error = errors.New("unexpected redis ping response: " + pong).Error()
|
||||
return status
|
||||
}
|
||||
status.OK = true
|
||||
return status
|
||||
}
|
||||
93
backend/internal/service/health_service_test.go
Normal file
93
backend/internal/service/health_service_test.go
Normal file
@ -0,0 +1,93 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHealthService_Liveness_AlwaysOK(t *testing.T) {
|
||||
s := NewHealthService(nil, nil)
|
||||
require.NoError(t, s.Liveness())
|
||||
}
|
||||
|
||||
func TestHealthService_Readiness_AllNilReturnsOK(t *testing.T) {
|
||||
// 当所有依赖都为 nil 时(早期启动或 unit test),readiness 应直接 OK。
|
||||
s := NewHealthService(nil, nil)
|
||||
report := s.Readiness(context.Background())
|
||||
require.True(t, report.OK)
|
||||
require.Empty(t, report.Details)
|
||||
}
|
||||
|
||||
func TestHealthService_Readiness_DBPingFails(t *testing.T) {
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectPing().WillReturnError(errors.New("connection refused"))
|
||||
|
||||
s := NewHealthService(db, nil)
|
||||
report := s.Readiness(context.Background())
|
||||
require.False(t, report.OK)
|
||||
require.Contains(t, report.Details, "database")
|
||||
require.False(t, report.Details["database"].OK)
|
||||
require.Contains(t, report.Details["database"].Error, "connection refused")
|
||||
}
|
||||
|
||||
func TestHealthService_Readiness_DBOK(t *testing.T) {
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectPing()
|
||||
|
||||
s := NewHealthService(db, nil)
|
||||
report := s.Readiness(context.Background())
|
||||
require.True(t, report.OK)
|
||||
require.True(t, report.Details["database"].OK)
|
||||
}
|
||||
|
||||
func TestHealthService_Readiness_RedisFails(t *testing.T) {
|
||||
// 指向一个不可达端口让 redis ping 立刻失败。
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 200 * time.Millisecond,
|
||||
ReadTimeout: 200 * time.Millisecond,
|
||||
})
|
||||
defer rdb.Close()
|
||||
|
||||
s := NewHealthService(nil, rdb)
|
||||
s.timeout = 500 * time.Millisecond
|
||||
report := s.Readiness(context.Background())
|
||||
require.False(t, report.OK)
|
||||
require.Contains(t, report.Details, "redis")
|
||||
require.False(t, report.Details["redis"].OK)
|
||||
}
|
||||
|
||||
func TestHealthService_Readiness_PerComponentTimeout(t *testing.T) {
|
||||
// 验证 readiness 在超时时不会无限挂住。
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
mock.ExpectPing().WillDelayFor(2 * time.Second)
|
||||
|
||||
s := NewHealthService(db, nil)
|
||||
s.timeout = 100 * time.Millisecond
|
||||
|
||||
start := time.Now()
|
||||
report := s.Readiness(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Less(t, elapsed, 1*time.Second, "readiness should respect per-component timeout")
|
||||
require.False(t, report.OK)
|
||||
require.NotEmpty(t, report.Details["database"].Error, "timeout should propagate as an error")
|
||||
}
|
||||
|
||||
// 抑制未使用包警告(database/sql 在签名里使用)。
|
||||
var _ = sql.ErrNoRows
|
||||
@ -8,6 +8,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器
|
||||
@ -30,12 +32,22 @@ type OAuthRefreshResult struct {
|
||||
}
|
||||
|
||||
// OAuthRefreshAPI 统一的 OAuth Token 刷新入口
|
||||
// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑
|
||||
// 封装分布式锁、进程内去重(singleflight)、DB 重读、已刷新检查、竞争恢复等通用逻辑
|
||||
//
|
||||
// 双层去重设计:
|
||||
// 1. 进程内 singleflight:合并同一 cacheKey 的并发调用(避免 100 个 goroutine
|
||||
// 都去抢同一把分布式锁、都重读一次 DB)。
|
||||
// 2. 跨进程分布式锁(Redis):保证集群范围内只有一个 worker 真正发起 OAuth
|
||||
// 刷新请求。
|
||||
//
|
||||
// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT;跨进程锁仍是必需的,
|
||||
// singleflight 解决不了多 pod 同时刷新。
|
||||
type OAuthRefreshAPI struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache // 可选,nil = 无分布式锁
|
||||
lockTTL time.Duration
|
||||
localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
// NewOAuthRefreshAPI 创建统一刷新 API
|
||||
@ -63,15 +75,19 @@ func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex {
|
||||
return mu
|
||||
}
|
||||
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token
|
||||
// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。
|
||||
//
|
||||
// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者"
|
||||
// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。
|
||||
//
|
||||
// 流程:
|
||||
// 1. 获取分布式锁
|
||||
// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 3. 二次检查是否仍需刷新
|
||||
// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 5. 设置 _token_version + 更新 DB
|
||||
// 6. 释放锁
|
||||
// 1. singleflight 合并同 cacheKey 并发调用
|
||||
// 2. 获取分布式锁(跨进程)
|
||||
// 3. 从 DB 重读最新 account(防止使用过时的 refresh_token)
|
||||
// 4. 二次检查是否仍需刷新
|
||||
// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑
|
||||
// 6. 设置 _token_version + 更新 DB
|
||||
// 7. 释放锁
|
||||
func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
@ -80,11 +96,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded(
|
||||
) (*OAuthRefreshResult, error) {
|
||||
cacheKey := executor.CacheKey(account)
|
||||
|
||||
// 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争)
|
||||
localMu := api.getLocalLock(cacheKey)
|
||||
localMu.Lock()
|
||||
defer localMu.Unlock()
|
||||
// singleflight key 同时区分 cacheKey 和 refreshWindow:
|
||||
// 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh,
|
||||
// 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。
|
||||
sfKey := cacheKey + "|" + refreshWindow.String()
|
||||
|
||||
v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) {
|
||||
return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, _ := v.(*OAuthRefreshResult)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。
|
||||
// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。
|
||||
func (api *OAuthRefreshAPI) refreshOnce(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
executor OAuthRefreshExecutor,
|
||||
refreshWindow time.Duration,
|
||||
cacheKey string,
|
||||
) (*OAuthRefreshResult, error) {
|
||||
// 1. 获取分布式锁
|
||||
lockAcquired := false
|
||||
if api.tokenCache != nil {
|
||||
|
||||
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
160
backend/internal/service/oauth_refresh_api_singleflight_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。
|
||||
type blockingExecutor struct {
|
||||
refreshAPIExecutorStub
|
||||
release chan struct{}
|
||||
concurrent int32 // 当前正在 Refresh 的 goroutine 数
|
||||
maxObserved int32 // 观察到的最大并发数
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
c := atomic.AddInt32(&e.concurrent, 1)
|
||||
for {
|
||||
old := atomic.LoadInt32(&e.maxObserved)
|
||||
if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
defer atomic.AddInt32(&e.concurrent, -1)
|
||||
|
||||
<-e.release
|
||||
return e.credentials, e.err
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) {
|
||||
// 同一 cacheKey 同时进入 N 个 goroutine,应只触发 1 次 executor.Refresh。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
const callers = 20
|
||||
results := make([]*OAuthRefreshResult, callers)
|
||||
errs := make([]error, callers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(callers)
|
||||
|
||||
for i := 0; i < callers; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||
results[i] = r
|
||||
errs[i] = err
|
||||
}()
|
||||
}
|
||||
|
||||
// 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once")
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously")
|
||||
|
||||
// 所有 caller 应拿到等价结果(不必同实例,singleflight Shared 标志会让多个 caller 共享)。
|
||||
for i := 0; i < callers; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.NotNil(t, results[i])
|
||||
require.True(t, results[i].Refreshed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) {
|
||||
// 不同账号有不同 cacheKey,应能并行刷新而非互相阻塞。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
platform := "p" + string(rune('a'+i))
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute)
|
||||
}()
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) {
|
||||
// 同 cacheKey 但不同 refreshWindow(前台短窗口 vs 后台长窗口)应分开判断
|
||||
// NeedsRefresh,避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。
|
||||
repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}}
|
||||
cache := &refreshAPICacheStub{lockResult: true}
|
||||
|
||||
exec := &blockingExecutor{
|
||||
refreshAPIExecutorStub: refreshAPIExecutorStub{
|
||||
needsRefresh: true,
|
||||
credentials: map[string]any{"access_token": "new"},
|
||||
},
|
||||
release: make(chan struct{}),
|
||||
}
|
||||
api := NewOAuthRefreshAPI(repo, cache)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour)
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged")
|
||||
|
||||
close(exec.release)
|
||||
wg.Wait()
|
||||
}
|
||||
@ -730,12 +730,18 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
quotaFactor := item.account.GetQuotaRemainingFraction()
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
weights.TTFT*ttftFactor +
|
||||
weights.Quota*quotaFactor
|
||||
}
|
||||
|
||||
if s.service.openAIWSP2CEnabled() {
|
||||
return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1193,6 +1199,7 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
|
||||
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
|
||||
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
|
||||
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
|
||||
Quota: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota,
|
||||
}
|
||||
}
|
||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||
@ -1201,15 +1208,21 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul
|
||||
Queue: 0.7,
|
||||
ErrorRate: 0.8,
|
||||
TTFT: 0.5,
|
||||
Quota: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSP2CEnabled() bool {
|
||||
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EnableP2CScheduling
|
||||
}
|
||||
|
||||
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
|
||||
Priority float64
|
||||
Load float64
|
||||
Queue float64
|
||||
ErrorRate float64
|
||||
TTFT float64
|
||||
Quota float64
|
||||
}
|
||||
|
||||
func clamp01(value float64) float64 {
|
||||
@ -1223,6 +1236,94 @@ func clamp01(value float64) float64 {
|
||||
}
|
||||
}
|
||||
|
||||
// selectByPowerOfTwo implements Power-of-Two-Choices (P2C): sample 2 random
|
||||
// candidates and attempt the better-scored one first, then the other.
|
||||
// This gives O(1) selection with load distribution comparable to top-K when N is large.
|
||||
func (s *defaultOpenAIAccountScheduler) selectByPowerOfTwo(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
candidates []openAIAccountCandidateScore,
|
||||
loadSkew float64,
|
||||
) (*AccountSelectionResult, int, int, float64, error) {
|
||||
n := len(candidates)
|
||||
if n == 0 {
|
||||
return nil, 0, 0, loadSkew, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
|
||||
|
||||
// Pick two distinct random indices.
|
||||
idxA := int(rng.nextUint64() % uint64(n))
|
||||
idxB := idxA
|
||||
if n > 1 {
|
||||
for idxB == idxA {
|
||||
idxB = int(rng.nextUint64() % uint64(n))
|
||||
}
|
||||
}
|
||||
|
||||
// Order: better candidate first.
|
||||
first, second := candidates[idxA], candidates[idxB]
|
||||
if isOpenAIAccountCandidateBetter(second, first) {
|
||||
first, second = second, first
|
||||
}
|
||||
|
||||
tryAcquire := func(c openAIAccountCandidateScore) (*AccountSelectionResult, bool, error) {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
return nil, false, nil
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, req.RequireCompact)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
return nil, false, nil
|
||||
}
|
||||
result, err := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
for _, c := range []openAIAccountCandidateScore{first, second} {
|
||||
result, ok, err := tryAcquire(c)
|
||||
if err != nil {
|
||||
return nil, n, 2, loadSkew, err
|
||||
}
|
||||
if ok {
|
||||
return result, n, 2, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Both slots busy — return wait plan on the better candidate.
|
||||
cfg := s.service.schedulingConfig()
|
||||
for _, c := range []openAIAccountCandidateScore{first, second} {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel, req.RequireCompact)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: fresh.ID,
|
||||
MaxConcurrency: fresh.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, n, 2, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, n, 2, loadSkew, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
|
||||
if count <= 1 {
|
||||
return 0
|
||||
|
||||
@ -1448,3 +1448,129 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
|
||||
func int64PtrForTest(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestAccount_GetQuotaRemainingFraction(t *testing.T) {
|
||||
// No limit configured → always 1.0 (unlimited)
|
||||
noLimit := &Account{}
|
||||
require.Equal(t, 1.0, noLimit.GetQuotaRemainingFraction())
|
||||
|
||||
// 50% used
|
||||
half := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 50.0}}
|
||||
require.InDelta(t, 0.5, half.GetQuotaRemainingFraction(), 1e-9)
|
||||
|
||||
// Fully exhausted
|
||||
full := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 100.0}}
|
||||
require.Equal(t, 0.0, full.GetQuotaRemainingFraction())
|
||||
|
||||
// Over limit → clamp to 0
|
||||
over := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 150.0}}
|
||||
require.Equal(t, 0.0, over.GetQuotaRemainingFraction())
|
||||
|
||||
// Fresh (0 used)
|
||||
fresh := &Account{Extra: map[string]any{"quota_limit": 200.0, "quota_used": 0.0}}
|
||||
require.Equal(t, 1.0, fresh.GetQuotaRemainingFraction())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_P2CEnabled(t *testing.T) {
|
||||
require.False(t, (*OpenAIGatewayService)(nil).openAIWSP2CEnabled())
|
||||
require.False(t, (&OpenAIGatewayService{}).openAIWSP2CEnabled())
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.EnableP2CScheduling = false
|
||||
require.False(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
|
||||
|
||||
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||
require.True(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SchedulerWeights_QuotaField(t *testing.T) {
|
||||
// Default weights: Quota is 0 (disabled by default)
|
||||
svc := &OpenAIGatewayService{}
|
||||
weights := svc.openAIWSSchedulerWeights()
|
||||
require.Equal(t, 0.0, weights.Quota)
|
||||
|
||||
// Config-driven quota weight
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota = 0.4
|
||||
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
|
||||
require.Equal(t, 0.4, svcWithCfg.openAIWSSchedulerWeights().Quota)
|
||||
}
|
||||
|
||||
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_SingleCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(99001)
|
||||
account := &Account{ID: 71001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
snapshotAccounts: []*Account{account},
|
||||
accountsByID: map[int64]*Account{71001: account},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||
cfg.Gateway.OpenAIWS.LBTopK = 5
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.Equal(t, int64(71001), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
}
|
||||
|
||||
func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_PicksBetterCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(99002)
|
||||
// Account A has low priority (better), B has high priority (worse).
|
||||
// With P2C enabled and a deterministic seed, we should always get a valid selection.
|
||||
accountA := &Account{ID: 72001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0}
|
||||
accountB := &Account{ID: 72002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 10}
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
snapshotAccounts: []*Account{accountA, accountB},
|
||||
accountsByID: map[int64]*Account{72001: accountA, 72002: accountB},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.EnableP2CScheduling = true
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*accountA, *accountB}},
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
// Either account is valid; just verify we got a schedulable one.
|
||||
require.True(t, selection.Account.ID == 72001 || selection.Account.ID == 72002)
|
||||
}
|
||||
|
||||
func TestDefaultOpenAIAccountScheduler_QuotaFactorInfluencesScore(t *testing.T) {
|
||||
// Verify that quota weight affects scoring by checking GetQuotaRemainingFraction is used.
|
||||
// Account with high remaining quota should score higher when quota weight > 0.
|
||||
highQuota := &Account{
|
||||
ID: 73001, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
|
||||
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
|
||||
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 10.0}, // 90% remaining
|
||||
}
|
||||
lowQuota := &Account{
|
||||
ID: 73002, Platform: PlatformOpenAI, Type: AccountTypeOAuth,
|
||||
Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
|
||||
Extra: map[string]any{"quota_limit": 100.0, "quota_used": 90.0}, // 10% remaining
|
||||
}
|
||||
|
||||
require.InDelta(t, 0.9, highQuota.GetQuotaRemainingFraction(), 1e-9)
|
||||
require.InDelta(t, 0.1, lowQuota.GetQuotaRemainingFraction(), 1e-9)
|
||||
|
||||
// With quota weight = 1.0 and all other weights = 0, high-quota account should win.
|
||||
// We verify the score ordering directly using isOpenAIAccountCandidateBetter.
|
||||
highScore := openAIAccountCandidateScore{account: highQuota, score: 0.9}
|
||||
lowScore := openAIAccountCandidateScore{account: lowQuota, score: 0.1}
|
||||
require.True(t, isOpenAIAccountCandidateBetter(highScore, lowScore))
|
||||
require.False(t, isOpenAIAccountCandidateBetter(lowScore, highScore))
|
||||
}
|
||||
|
||||
75
backend/internal/service/request_event_bus.go
Normal file
75
backend/internal/service/request_event_bus.go
Normal file
@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const requestEventBufSize = 64
|
||||
|
||||
// RequestEvent is published for every gateway dispatch completion.
|
||||
type RequestEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
Model string `json:"model"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
// Status is "success", "error", or "rate_limited".
|
||||
Status string `json:"status"`
|
||||
LatencyMS int64 `json:"latency_ms"`
|
||||
}
|
||||
|
||||
// RequestEventBus is a fan-out hub for real-time request events.
|
||||
// Publishers call Publish; subscribers call Subscribe/Unsubscribe.
|
||||
// Each subscriber gets its own buffered channel. If the buffer is full
|
||||
// the event is dropped for that subscriber (non-blocking publish).
|
||||
type RequestEventBus struct {
|
||||
mu sync.RWMutex
|
||||
subscribers map[uint64]chan RequestEvent
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
func NewRequestEventBus() *RequestEventBus {
|
||||
return &RequestEventBus{
|
||||
subscribers: make(map[uint64]chan RequestEvent),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new subscriber and returns its ID and a receive-only channel.
|
||||
func (b *RequestEventBus) Subscribe() (uint64, <-chan RequestEvent) {
|
||||
id := b.nextID.Add(1)
|
||||
ch := make(chan RequestEvent, requestEventBufSize)
|
||||
b.mu.Lock()
|
||||
b.subscribers[id] = ch
|
||||
b.mu.Unlock()
|
||||
return id, ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscriber and closes its channel.
|
||||
func (b *RequestEventBus) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
ch, ok := b.subscribers[id]
|
||||
if ok {
|
||||
delete(b.subscribers, id)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
if ok {
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish sends an event to all current subscribers without blocking.
|
||||
func (b *RequestEventBus) Publish(e RequestEvent) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
for _, ch := range b.subscribers {
|
||||
select {
|
||||
case ch <- e:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
100
backend/internal/service/request_event_bus_test.go
Normal file
100
backend/internal/service/request_event_bus_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestEventBus_PublishToSubscriber(t *testing.T) {
|
||||
bus := NewRequestEventBus()
|
||||
|
||||
id, ch := bus.Subscribe()
|
||||
defer bus.Unsubscribe(id)
|
||||
|
||||
evt := RequestEvent{Model: "claude-3", Status: "success", LatencyMS: 100}
|
||||
bus.Publish(evt)
|
||||
|
||||
select {
|
||||
case got := <-ch:
|
||||
assert.Equal(t, evt, got)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestEventBus_MultipleSubscribers(t *testing.T) {
|
||||
bus := NewRequestEventBus()
|
||||
|
||||
id1, ch1 := bus.Subscribe()
|
||||
id2, ch2 := bus.Subscribe()
|
||||
defer bus.Unsubscribe(id1)
|
||||
defer bus.Unsubscribe(id2)
|
||||
|
||||
evt := RequestEvent{Model: "claude-3", Status: "error"}
|
||||
bus.Publish(evt)
|
||||
|
||||
for _, ch := range []<-chan RequestEvent{ch1, ch2} {
|
||||
select {
|
||||
case got := <-ch:
|
||||
assert.Equal(t, evt, got)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event on one subscriber")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestEventBus_UnsubscribeClosesChannel(t *testing.T) {
|
||||
bus := NewRequestEventBus()
|
||||
id, ch := bus.Subscribe()
|
||||
|
||||
bus.Unsubscribe(id)
|
||||
|
||||
// Channel should be closed.
|
||||
_, ok := <-ch
|
||||
assert.False(t, ok, "channel should be closed after Unsubscribe")
|
||||
}
|
||||
|
||||
func TestRequestEventBus_UnsubscribedMissesEvents(t *testing.T) {
|
||||
bus := NewRequestEventBus()
|
||||
id, _ := bus.Subscribe()
|
||||
bus.Unsubscribe(id)
|
||||
|
||||
// Publish after unsubscribe should not panic.
|
||||
require.NotPanics(t, func() {
|
||||
bus.Publish(RequestEvent{Model: "test"})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestEventBus_DropWhenFull(t *testing.T) {
|
||||
bus := NewRequestEventBus()
|
||||
id, ch := bus.Subscribe()
|
||||
defer bus.Unsubscribe(id)
|
||||
|
||||
// Fill the buffer then publish one more — should drop, not block.
|
||||
evt := RequestEvent{Model: "model", Status: "success"}
|
||||
for i := 0; i < requestEventBufSize; i++ {
|
||||
bus.Publish(evt)
|
||||
}
|
||||
// This publish should return immediately (dropped).
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bus.Publish(evt)
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Publish blocked when buffer was full")
|
||||
}
|
||||
assert.Len(t, ch, requestEventBufSize)
|
||||
}
|
||||
|
||||
func TestRequestEventBus_NilSafePublish(t *testing.T) {
|
||||
var bus *RequestEventBus
|
||||
require.NotPanics(t, func() {
|
||||
bus.Publish(RequestEvent{Model: "test"})
|
||||
})
|
||||
}
|
||||
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait.
|
||||
var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
|
||||
|
||||
// RPMTokenBucketService provides per-account token buckets for RPM smoothing.
|
||||
// When an account's RPM budget is exhausted, callers can wait up to a configured
|
||||
// deadline instead of receiving an immediate 429. The bucket refills continuously
|
||||
// at rpm/60 tokens per second so requests are distributed evenly over time.
|
||||
type RPMTokenBucketService struct {
|
||||
buckets sync.Map // map[int64]*rpmEntry
|
||||
}
|
||||
|
||||
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
|
||||
func NewRPMTokenBucketService() *RPMTokenBucketService {
|
||||
return &RPMTokenBucketService{}
|
||||
}
|
||||
|
||||
type rpmEntry struct {
|
||||
bucket *tokenBucket
|
||||
rpm int
|
||||
}
|
||||
|
||||
// getBucket returns (or creates) the token bucket for accountID.
|
||||
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
|
||||
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
|
||||
if v, ok := s.buckets.Load(accountID); ok {
|
||||
e := v.(*rpmEntry)
|
||||
if e.rpm == rpm {
|
||||
return e.bucket
|
||||
}
|
||||
// RPM limit changed — replace with a fresh bucket.
|
||||
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||
s.buckets.Store(accountID, fresh)
|
||||
return fresh.bucket
|
||||
}
|
||||
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||
actual, _ := s.buckets.LoadOrStore(accountID, entry)
|
||||
return actual.(*rpmEntry).bucket
|
||||
}
|
||||
|
||||
// AcquireWithWait attempts to consume one token for the given account.
|
||||
// It blocks up to maxWait for a token to become available.
|
||||
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
|
||||
// or ctx.Err() if the context is cancelled.
|
||||
// If rpm <= 0 the call returns immediately with nil.
|
||||
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||
if rpm <= 0 {
|
||||
return nil
|
||||
}
|
||||
bucket := s.getBucket(accountID, rpm)
|
||||
deadline := time.Now().Add(maxWait)
|
||||
|
||||
for {
|
||||
ok, waitDur := bucket.tryAcquire()
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 || waitDur > remaining {
|
||||
return ErrRPMWaitTimeout
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(waitDur):
|
||||
// token may be available now; retry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tokenBucket is a continuous-refill token bucket for a single account.
|
||||
type tokenBucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
maxTokens float64
|
||||
rateSec float64 // tokens refilled per second = rpm / 60
|
||||
lastFill time.Time
|
||||
}
|
||||
|
||||
func newTokenBucket(rpm int) *tokenBucket {
|
||||
max := float64(rpm)
|
||||
return &tokenBucket{
|
||||
tokens: max,
|
||||
maxTokens: max,
|
||||
rateSec: float64(rpm) / 60.0,
|
||||
lastFill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
|
||||
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
|
||||
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(b.lastFill).Seconds()
|
||||
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
|
||||
b.lastFill = now
|
||||
|
||||
if b.tokens >= 1.0 {
|
||||
b.tokens -= 1.0
|
||||
return true, 0
|
||||
}
|
||||
|
||||
deficit := 1.0 - b.tokens
|
||||
waitSecs := deficit / b.rateSec
|
||||
return false, time.Duration(waitSecs * float64(time.Second))
|
||||
}
|
||||
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
// Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately.
|
||||
for i := 0; i < 60; i++ {
|
||||
err := svc.AcquireWithWait(ctx, 1, 60, 0)
|
||||
require.NoError(t, err, "call %d should succeed immediately", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
err := svc.AcquireWithWait(context.Background(), 42, 0, 0)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// rpm=1 → 1 token/minute. One call drains the bucket.
|
||||
err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second)
|
||||
require.NoError(t, err, "first call should succeed")
|
||||
|
||||
// Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms.
|
||||
start := time.Now()
|
||||
err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
assert.ErrorIs(t, err, ErrRPMWaitTimeout)
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully.
|
||||
for i := 0; i < 120; i++ {
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0))
|
||||
}
|
||||
|
||||
// Next call needs to wait ~500ms for the next token. Give it 2s.
|
||||
start := time.Now()
|
||||
err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
require.NoError(t, err, "should succeed after waiting for refill")
|
||||
assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited")
|
||||
assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_ContextCancellation(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
|
||||
// rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining.
|
||||
// maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms).
|
||||
// Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first.
|
||||
for i := 0; i < 120; i++ {
|
||||
require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// Drain account 1 (rpm=1).
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0))
|
||||
|
||||
// Account 2 has its own bucket and should succeed immediately.
|
||||
err := svc.AcquireWithWait(ctx, 2, 1, 0)
|
||||
assert.NoError(t, err, "different account should have an independent bucket")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create bucket with rpm=1 and drain it.
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0))
|
||||
// Bucket now empty with rpm=1.
|
||||
|
||||
// Changing RPM to 60 should reset the bucket to full (60 tokens).
|
||||
err := svc.AcquireWithWait(ctx, 10, 60, 0)
|
||||
assert.NoError(t, err, "new RPM should cause bucket recreation")
|
||||
}
|
||||
@ -426,6 +426,8 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideBillingCacheService,
|
||||
NewAnnouncementService,
|
||||
NewAdminService,
|
||||
NewRPMTokenBucketService,
|
||||
NewRequestEventBus,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
@ -448,6 +450,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideSettingService,
|
||||
NewDataManagementService,
|
||||
ProvideBackupService,
|
||||
NewHealthService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
ProvideOpsMetricsCollector,
|
||||
|
||||
@ -120,7 +120,7 @@ services:
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: [ "CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health" ]
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/ready"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user