From d6df41feaaecd94caceaeee112a8a58b4d7f2e8a Mon Sep 17 00:00:00 2001 From: win Date: Tue, 28 Apr 2026 22:35:24 +0800 Subject: [PATCH 1/8] chore(claude): bump CLI fingerprint to 2.1.88 and accept claude-code/ UA - Centralize Claude CLI fingerprint constants (UA, x-stainless-*) in pkg/claude with BuildCLI/CodeUserAgent helpers - Reuse constants in DefaultHeaders, identity_service defaults, and antigravity identity defaults to keep all callers in sync - Extend ClaudeCodeValidator to accept both claude-cli/ and claude-code/ UA prefixes (transport/helper requests use the latter) - Update related tests to cover the new UA prefix and version --- backend/internal/pkg/claude/constants.go | 58 ++++++++++++++++--- .../repository/claude_usage_service.go | 5 +- .../service/claude_code_detection_test.go | 14 +++++ .../internal/service/claude_code_validator.go | 10 ++-- .../service/claude_code_validator_test.go | 1 + ...teway_anthropic_apikey_passthrough_test.go | 35 +++++++++++ .../internal/service/gateway_prompt_test.go | 6 ++ backend/internal/service/gateway_service.go | 28 +++++---- backend/internal/service/identity_service.go | 15 ++--- .../service/identity_service_antigravity.go | 4 +- 10 files changed, 139 insertions(+), 37 deletions(-) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index c9c015bb..95f44630 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -1,8 +1,24 @@ // Package claude provides constants and helpers for Claude API integration. package claude +import "strings" + // Claude Code 客户端相关常量 +const ( + DefaultCLIProductVersion = "2.1.88" + DefaultUserType = "external" + DefaultEntrypoint = "cli" + DefaultStainlessLang = "js" + DefaultStainlessPackageVersion = "0.74.0" + DefaultStainlessOS = "MacOS" + DefaultStainlessArch = "arm64" + DefaultStainlessRuntime = "node" + DefaultStainlessRuntimeVersion = "v24.3.0" + DefaultCLIUserAgent = "claude-cli/" + DefaultCLIProductVersion + " (" + DefaultUserType + ", " + DefaultEntrypoint + ")" + DefaultCodeUserAgent = "claude-code/" + DefaultCLIProductVersion +) + // Beta header 常量 const ( BetaOAuth = "oauth-2025-04-20" @@ -52,19 +68,45 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking var DefaultHeaders = map[string]string{ // Keep these in sync with recent Claude CLI traffic to reduce the chance // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage. - "User-Agent": "claude-cli/2.1.84 (external, cli)", - "X-Stainless-Lang": "js", - "X-Stainless-Package-Version": "0.74.0", - "X-Stainless-OS": "MacOS", - "X-Stainless-Arch": "arm64", - "X-Stainless-Runtime": "node", - "X-Stainless-Runtime-Version": "v24.3.0", + "User-Agent": DefaultCLIUserAgent, + "X-Stainless-Lang": DefaultStainlessLang, + "X-Stainless-Package-Version": DefaultStainlessPackageVersion, + "X-Stainless-OS": DefaultStainlessOS, + "X-Stainless-Arch": DefaultStainlessArch, + "X-Stainless-Runtime": DefaultStainlessRuntime, + "X-Stainless-Runtime-Version": DefaultStainlessRuntimeVersion, "X-Stainless-Retry-Count": "0", "X-Stainless-Timeout": "600", "X-App": "cli", "Anthropic-Dangerous-Direct-Browser-Access": "true", } +// BuildCLIUserAgent returns the current Claude Code API client user-agent. +func BuildCLIUserAgent(version, userType, entrypoint string) string { + version = strings.TrimSpace(version) + if version == "" { + version = DefaultCLIProductVersion + } + userType = strings.TrimSpace(userType) + if userType == "" { + userType = DefaultUserType + } + entrypoint = strings.TrimSpace(entrypoint) + if entrypoint == "" { + entrypoint = DefaultEntrypoint + } + return "claude-cli/" + version + " (" + userType + ", " + entrypoint + ")" +} + +// BuildCodeUserAgent returns the current Claude Code transport/helper user-agent. +func BuildCodeUserAgent(version string) string { + version = strings.TrimSpace(version) + if version == "" { + version = DefaultCLIProductVersion + } + return "claude-code/" + version +} + // ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值) // cliVersion: Claude CLI 版本(如 "2.1.81") // pkgVersion: SDK 版本(如 "0.80.0") @@ -73,7 +115,7 @@ var DefaultHeaders = map[string]string{ // arch: 架构(如 "arm64") func ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) { if cliVersion != "" { - DefaultHeaders["User-Agent"] = "claude-cli/" + cliVersion + " (external, cli)" + DefaultHeaders["User-Agent"] = BuildCLIUserAgent(cliVersion, "", "") } if pkgVersion != "" { DefaultHeaders["X-Stainless-Package-Version"] = pkgVersion diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index b44adde2..15329507 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" @@ -15,8 +16,8 @@ import ( const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" -// 默认 User-Agent,与用户抓包的请求一致 -const defaultUsageUserAgent = "claude-code/2.1.7" +// 默认 User-Agent,与 Claude Code 2.1.88 的 helper/transport 请求保持一致。 +const defaultUsageUserAgent = claude.DefaultCodeUserAgent type claudeUsageService struct { usageURL string diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go index ff7ad7f4..463aa60d 100644 --- a/backend/internal/service/claude_code_detection_test.go +++ b/backend/internal/service/claude_code_detection_test.go @@ -40,6 +40,7 @@ func TestValidate_ClaudeCLIUserAgent(t *testing.T) { want bool }{ {"标准版本号", "claude-cli/1.0.0", true}, + {"官方 transport UA", "claude-code/2.1.88", true}, {"多位版本号", "claude-cli/12.34.56", true}, {"大写开头", "Claude-CLI/1.0.0", true}, {"非 claude-cli", "curl/7.64.1", false}, @@ -90,6 +91,19 @@ func TestValidate_MessagesPath_FullValid(t *testing.T) { require.True(t, result, "完整有效请求应通过") } +func TestValidate_MessagesPath_FullValid_ClaudeCodeUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-code/2.1.88") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "官方 transport/helper UA 也应通过") +} + func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { v := newTestValidator() body := validClaudeCodeBody() diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index 4e8ced67..a40ccb95 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -15,11 +15,13 @@ import ( type ClaudeCodeValidator struct{} var ( - // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) - claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) + // User-Agent 匹配: 官方 Claude Code 目前存在两类产品前缀: + // 1. 主 Anthropic API 客户端: claude-cli/x.y.z (...) + // 2. transport / helper 请求: claude-code/x.y.z + claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/\d+\.\d+\.\d+`) // 带捕获组的版本提取正则 - claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) + claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-(?:cli|code)/(\d+\.\d+\.\d+)`) // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) systemPromptThreshold = 0.5 @@ -55,7 +57,7 @@ func NewClaudeCodeValidator() *ClaudeCodeValidator { // Validate 验证请求是否来自 Claude Code CLI // 采用与 claude-relay-service 完全一致的验证策略: // -// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x +// Step 1: User-Agent 检查 (必需) - 必须是官方 claude-cli/ 或 claude-code/ 前缀 // Step 2: 对于非 messages 路径,只要 UA 匹配就通过 // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证) // Step 4: 对于 messages 路径,进行严格验证: diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go index f87c56e8..fd7e26da 100644 --- a/backend/internal/service/claude_code_validator_test.go +++ b/backend/internal/service/claude_code_validator_test.go @@ -64,6 +64,7 @@ func TestExtractVersion(t *testing.T) { want string }{ {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"}, + {"claude-code/2.1.88", "2.1.88"}, {"claude-cli/1.0.0", "1.0.0"}, {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感 {"curl/8.0.0", ""}, // 非 Claude CLI diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 6e19db32..9b962455 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -689,6 +689,41 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t require.Contains(t, getHeaderRaw(req.Header, "anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") } +func TestGatewayService_AnthropicOAuth_InjectsClaudeCodeSessionHeaderFromMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + sessionID := "12345678-1234-1234-1234-123456789abc" + body, err := json.Marshal(map[string]any{ + "model": "claude-3-7-sonnet-20250219", + "metadata": map[string]any{ + "user_id": FormatMetadataUserID( + "d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169", + "", + sessionID, + claude.DefaultCLIProductVersion, + ), + }, + }) + require.NoError(t, err) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + + req, err := svc.buildUpstreamRequest(context.Background(), c, account, body, "oauth-token", "oauth", "claude-3-7-sonnet-20250219", false, false) + require.NoError(t, err) + require.Equal(t, sessionID, getHeaderRaw(req.Header, "X-Claude-Code-Session-Id")) +} + func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index 356536b0..e52a1a81 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -21,6 +21,12 @@ func TestIsClaudeCodeClient(t *testing.T) { metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", want: true, }, + { + name: "Claude Code helper client", + userAgent: "claude-code/2.1.88", + metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + want: true, + }, { name: "Claude Code without version suffix", userAgent: "claude-cli/2.0.0", diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b54f463b..107e5086 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -328,8 +328,8 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( - sseDataRe = regexp.MustCompile(`^data:\s*`) - claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + sseDataRe = regexp.MustCompile(`^data:\s*`) + claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -3739,7 +3739,7 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { if metadataUserID == "" { return false } - return claudeCliUserAgentRe.MatchString(userAgent) + return claudeCodeUserAgentRe.MatchString(userAgent) } func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { @@ -5758,12 +5758,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } + // Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。 + // 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。 + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) } } @@ -8486,12 +8485,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } + // Claude Code 主 API 客户端会始终发送 X-Claude-Code-Session-Id。 + // 对于 mimic / 转发场景,只要 body 中 metadata.user_id 可解析,就主动注入并同步该头。 + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) } } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index c6a260a8..43351f73 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -26,13 +27,13 @@ var ( // 默认指纹值(当客户端未提供时使用) var defaultFingerprint = Fingerprint{ - UserAgent: "claude-cli/2.1.84 (external, cli)", - StainlessLang: "js", - StainlessPackageVersion: "0.74.0", - StainlessOS: "MacOS", - StainlessArch: "arm64", - StainlessRuntime: "node", - StainlessRuntimeVersion: "v24.3.0", + UserAgent: claude.DefaultCLIUserAgent, + StainlessLang: claude.DefaultStainlessLang, + StainlessPackageVersion: claude.DefaultStainlessPackageVersion, + StainlessOS: claude.DefaultStainlessOS, + StainlessArch: claude.DefaultStainlessArch, + StainlessRuntime: claude.DefaultStainlessRuntime, + StainlessRuntimeVersion: claude.DefaultStainlessRuntimeVersion, } // Fingerprint represents account fingerprint data diff --git a/backend/internal/service/identity_service_antigravity.go b/backend/internal/service/identity_service_antigravity.go index e725a7fb..8416aff1 100644 --- a/backend/internal/service/identity_service_antigravity.go +++ b/backend/internal/service/identity_service_antigravity.go @@ -1,5 +1,7 @@ package service +import "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + // ============================================================== // antigravity — identity_service 扩展 // @@ -15,7 +17,7 @@ package service // 允许不同部署实例设置不同的 CLI/SDK 版本号,避免所有实例指纹相同 func ApplyDefaultFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) { if cliVersion != "" { - defaultFingerprint.UserAgent = "claude-cli/" + cliVersion + " (external, cli)" + defaultFingerprint.UserAgent = claude.BuildCLIUserAgent(cliVersion, "", "") } if pkgVersion != "" { defaultFingerprint.StainlessPackageVersion = pkgVersion From 110902ad4b76c9d3e2c863504454316ad0ce20df Mon Sep 17 00:00:00 2001 From: win Date: Tue, 28 Apr 2026 23:39:50 +0800 Subject: [PATCH 2/8] feat(health): split liveness and readiness probes Add HealthService with Liveness (no-op) and Readiness (DB+Redis ping with per-component timeout) checks. Expose three endpoints: - /healthz : new liveness endpoint, zero-dependency, always 200 - /ready : new readiness endpoint, returns 503 with details on dep failure; suitable for K8s readinessProbe and load balancers - /health : preserved for backward compatibility, equivalent to /healthz Switch primary docker-compose healthcheck to /ready so the container is only marked healthy once DB+Redis are reachable. Standalone/dev/ local compose files keep /health to avoid disrupting existing setups. Tests: unit tests cover liveness, readiness with both deps healthy, each dep failing independently, and per-component timeout enforcement. --- backend/cmd/server/wire_gen.go | 3 +- backend/go.mod | 2 + backend/go.sum | 2 + backend/internal/server/http.go | 3 +- backend/internal/server/router.go | 6 +- backend/internal/server/routes/common.go | 37 +++++- backend/internal/server/routes/common_test.go | 49 ++++++++ backend/internal/service/health_service.go | 119 ++++++++++++++++++ .../internal/service/health_service_test.go | 93 ++++++++++++++ backend/internal/service/wire.go | 1 + deploy/docker-compose.yml | 2 +- 11 files changed, 308 insertions(+), 9 deletions(-) create mode 100644 backend/internal/server/routes/common_test.go create mode 100644 backend/internal/service/health_service.go create mode 100644 backend/internal/service/health_service_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index d0dcacd2..ffb53780 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -237,7 +237,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient) + healthService := service.NewHealthService(db, redisClient) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient) httpServer := server.ProvideHTTPServer(configConfig, engine) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) diff --git a/backend/go.mod b/backend/go.mod index 135cbd3e..509619b1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -107,6 +107,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -176,6 +177,7 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index f5b7968f..c8102f65 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -180,6 +180,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index a8034e98..a9e3524c 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -35,6 +35,7 @@ func ProvideRouter( subscriptionService *service.SubscriptionService, opsService *service.OpsService, settingService *service.SettingService, + healthService *service.HealthService, redisClient *redis.Client, ) *gin.Engine { if cfg.Server.Mode == "release" { @@ -56,7 +57,7 @@ func ProvideRouter( } } - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 99701531..d532bd7f 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -30,6 +30,7 @@ func SetupRouter( subscriptionService *service.SubscriptionService, opsService *service.OpsService, settingService *service.SettingService, + healthService *service.HealthService, cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { @@ -81,7 +82,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient) return r } @@ -97,11 +98,12 @@ func registerRoutes( subscriptionService *service.SubscriptionService, opsService *service.OpsService, settingService *service.SettingService, + healthService *service.HealthService, cfg *config.Config, redisClient *redis.Client, ) { // 通用路由(健康检查、状态等) - routes.RegisterCommonRoutes(r) + routes.RegisterCommonRoutes(r, healthService) // API v1 v1 := r.Group("/api/v1") diff --git a/backend/internal/server/routes/common.go b/backend/internal/server/routes/common.go index 4989358d..bd71dc12 100644 --- a/backend/internal/server/routes/common.go +++ b/backend/internal/server/routes/common.go @@ -1,16 +1,45 @@ package routes import ( + "context" "net/http" + "time" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) -// RegisterCommonRoutes 注册通用路由(健康检查、状态等) -func RegisterCommonRoutes(r *gin.Engine) { - // 健康检查 - r.GET("/health", func(c *gin.Context) { +// readinessHandlerTimeout 限定 readiness 端点对外的最大返回耗时。 +// HealthService 内部对每个组件再有独立超时,所以这里给宽一点即可。 +const readinessHandlerTimeout = 3 * time.Second + +// RegisterCommonRoutes 注册通用路由(健康检查、状态等)。 +// +// 健康端点的语义分层: +// - /healthz : liveness 探针。零依赖、永远 200。容器/进程探活专用。 +// - /ready : readiness 探针。检查 DB+Redis;任一失败返回 503。 +// - /health : 历史端点,等价于 /healthz,保留向后兼容。 +// +// dashboard 用的"业务健康分"由 ops_health_score 单独提供,与本路由无关。 +func RegisterCommonRoutes(r *gin.Engine, healthService *service.HealthService) { + // Liveness:仅证明进程在响应。 + livenessHandler := func(c *gin.Context) { + _ = healthService.Liveness() c.JSON(http.StatusOK, gin.H{"status": "ok"}) + } + r.GET("/healthz", livenessHandler) + r.GET("/health", livenessHandler) // 向后兼容旧的 docker-compose healthcheck + + // Readiness:检查关键依赖。失败时返回 503 但仍带详情,便于排障。 + r.GET("/ready", func(c *gin.Context) { + ctx, cancel := context.WithTimeout(c.Request.Context(), readinessHandlerTimeout) + defer cancel() + report := healthService.Readiness(ctx) + status := http.StatusOK + if !report.OK { + status = http.StatusServiceUnavailable + } + c.JSON(status, report) }) // Claude Code 遥测日志(忽略,直接返回200) diff --git a/backend/internal/server/routes/common_test.go b/backend/internal/server/routes/common_test.go new file mode 100644 index 00000000..51a5e43c --- /dev/null +++ b/backend/internal/server/routes/common_test.go @@ -0,0 +1,49 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newTestRouter(t *testing.T, hs *service.HealthService) *gin.Engine { + t.Helper() + gin.SetMode(gin.TestMode) + r := gin.New() + RegisterCommonRoutes(r, hs) + return r +} + +func TestCommonRoutes_LivenessEndpoints(t *testing.T) { + r := newTestRouter(t, service.NewHealthService(nil, nil)) + for _, path := range []string{"/healthz", "/health"} { + req := httptest.NewRequest(http.MethodGet, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code, "liveness path %s should be 200", path) + } +} + +func TestCommonRoutes_ReadyEndpoint_NoDepsReturnsOK(t *testing.T) { + // 没有 DB/Redis 依赖时 readiness 视为 ok(早期启动场景)。 + r := newTestRouter(t, service.NewHealthService(nil, nil)) + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + require.Contains(t, w.Body.String(), "\"ok\":true") +} + +func TestCommonRoutes_SetupStatusUnchanged(t *testing.T) { + // 验证我们没有破坏既有的 /setup/status 行为(前端依赖)。 + r := newTestRouter(t, service.NewHealthService(nil, nil)) + req := httptest.NewRequest(http.MethodGet, "/setup/status", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + require.Contains(t, w.Body.String(), "needs_setup") +} diff --git a/backend/internal/service/health_service.go b/backend/internal/service/health_service.go new file mode 100644 index 00000000..eb709da2 --- /dev/null +++ b/backend/internal/service/health_service.go @@ -0,0 +1,119 @@ +// Package service - HealthService 提供 liveness 与 readiness 探针。 +// +// 设计动机:原有 /health 端点既被 docker-compose healthcheck 使用,又被 +// dashboard 的 ops_health_score 复用——后者会触发 DB/Redis 等重操作, +// 导致探活流量污染监控指标。本服务把两类语义拆开: +// - Liveness : 仅证明进程存活(无外部依赖检查)。 +// - Readiness : 检查 DB + Redis 连通,作为是否可接收流量的判断。 +// +// dashboard 维度的"业务健康分"仍由 ops_health_score 计算,与本服务无关。 +package service + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +// 探针默认超时。Readiness 探针需要快速失败,避免堆积。 +const ( + defaultReadinessTimeout = 2 * time.Second +) + +// ReadinessReport 描述各依赖项的状态,便于上层暴露细节给排障。 +type ReadinessReport struct { + OK bool `json:"ok"` + Details map[string]ComponentStatus `json:"details"` + Elapsed time.Duration `json:"elapsed_ms"` +} + +// ComponentStatus 单个依赖项的状态。Error 字段在 OK=true 时为空。 +type ComponentStatus struct { + OK bool `json:"ok"` + Error string `json:"error,omitempty"` + Elapsed string `json:"elapsed,omitempty"` +} + +// HealthService 提供 liveness/readiness 探针。 +// 字段都允许为 nil:缺失的依赖在 readiness 中自动跳过,便于测试和分阶段启用。 +type HealthService struct { + db *sql.DB + rdb *redis.Client + timeout time.Duration +} + +// NewHealthService 构造函数。timeout<=0 时使用默认值。 +func NewHealthService(db *sql.DB, rdb *redis.Client) *HealthService { + return &HealthService{ + db: db, + rdb: rdb, + timeout: defaultReadinessTimeout, + } +} + +// Liveness 仅返回 nil。任何调用方能拿到这个返回值就说明进程在响应请求。 +// 保持无副作用、零依赖,便于 K8s livenessProbe 高频调用。 +func (s *HealthService) Liveness() error { + return nil +} + +// Readiness 检查所有外部依赖。任一失败则整体 OK=false。 +// 单个依赖的 ctx 超时由 timeout 控制,独立计时不互相阻塞。 +func (s *HealthService) Readiness(ctx context.Context) ReadinessReport { + start := time.Now() + report := ReadinessReport{ + OK: true, + Details: make(map[string]ComponentStatus, 2), + } + + if s.db != nil { + report.Details["database"] = s.checkDB(ctx) + if !report.Details["database"].OK { + report.OK = false + } + } + if s.rdb != nil { + report.Details["redis"] = s.checkRedis(ctx) + if !report.Details["redis"].OK { + report.OK = false + } + } + + report.Elapsed = time.Since(start) + return report +} + +func (s *HealthService) checkDB(parent context.Context) ComponentStatus { + ctx, cancel := context.WithTimeout(parent, s.timeout) + defer cancel() + start := time.Now() + err := s.db.PingContext(ctx) + status := ComponentStatus{Elapsed: time.Since(start).String()} + if err != nil { + status.Error = err.Error() + return status + } + status.OK = true + return status +} + +func (s *HealthService) checkRedis(parent context.Context) ComponentStatus { + ctx, cancel := context.WithTimeout(parent, s.timeout) + defer cancel() + start := time.Now() + pong, err := s.rdb.Ping(ctx).Result() + status := ComponentStatus{Elapsed: time.Since(start).String()} + if err != nil { + status.Error = err.Error() + return status + } + if pong != "PONG" { + status.Error = errors.New("unexpected redis ping response: " + pong).Error() + return status + } + status.OK = true + return status +} diff --git a/backend/internal/service/health_service_test.go b/backend/internal/service/health_service_test.go new file mode 100644 index 00000000..0fa8d931 --- /dev/null +++ b/backend/internal/service/health_service_test.go @@ -0,0 +1,93 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestHealthService_Liveness_AlwaysOK(t *testing.T) { + s := NewHealthService(nil, nil) + require.NoError(t, s.Liveness()) +} + +func TestHealthService_Readiness_AllNilReturnsOK(t *testing.T) { + // 当所有依赖都为 nil 时(早期启动或 unit test),readiness 应直接 OK。 + s := NewHealthService(nil, nil) + report := s.Readiness(context.Background()) + require.True(t, report.OK) + require.Empty(t, report.Details) +} + +func TestHealthService_Readiness_DBPingFails(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + mock.ExpectPing().WillReturnError(errors.New("connection refused")) + + s := NewHealthService(db, nil) + report := s.Readiness(context.Background()) + require.False(t, report.OK) + require.Contains(t, report.Details, "database") + require.False(t, report.Details["database"].OK) + require.Contains(t, report.Details["database"].Error, "connection refused") +} + +func TestHealthService_Readiness_DBOK(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + + mock.ExpectPing() + + s := NewHealthService(db, nil) + report := s.Readiness(context.Background()) + require.True(t, report.OK) + require.True(t, report.Details["database"].OK) +} + +func TestHealthService_Readiness_RedisFails(t *testing.T) { + // 指向一个不可达端口让 redis ping 立刻失败。 + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 200 * time.Millisecond, + ReadTimeout: 200 * time.Millisecond, + }) + defer rdb.Close() + + s := NewHealthService(nil, rdb) + s.timeout = 500 * time.Millisecond + report := s.Readiness(context.Background()) + require.False(t, report.OK) + require.Contains(t, report.Details, "redis") + require.False(t, report.Details["redis"].OK) +} + +func TestHealthService_Readiness_PerComponentTimeout(t *testing.T) { + // 验证 readiness 在超时时不会无限挂住。 + db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) + require.NoError(t, err) + defer db.Close() + mock.ExpectPing().WillDelayFor(2 * time.Second) + + s := NewHealthService(db, nil) + s.timeout = 100 * time.Millisecond + + start := time.Now() + report := s.Readiness(context.Background()) + elapsed := time.Since(start) + + require.Less(t, elapsed, 1*time.Second, "readiness should respect per-component timeout") + require.False(t, report.OK) + require.NotEmpty(t, report.Details["database"].Error, "timeout should propagate as an error") +} + +// 抑制未使用包警告(database/sql 在签名里使用)。 +var _ = sql.ErrNoRows diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d79a3531..abf437d5 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -451,6 +451,7 @@ var ProviderSet = wire.NewSet( ProvideSettingService, NewDataManagementService, ProvideBackupService, + NewHealthService, ProvideOpsSystemLogSink, NewOpsService, ProvideOpsMetricsCollector, diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index bb213c76..a5415298 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -168,7 +168,7 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] + test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/ready"] interval: 30s timeout: 10s retries: 3 From 5c8c15cdb1707425350deee50f2e1aae293d8e31 Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 00:43:23 +0800 Subject: [PATCH 3/8] feat(refresh,repo): add singleflight to dedupe concurrent token refresh and unschedulable writes Two anti-thundering-herd improvements: 1. OAuthRefreshAPI.RefreshIfNeeded Wrap the existing distributed-lock + DB-reread + executor.Refresh pipeline in a per-process singleflight keyed by cacheKey+window. Without this, N concurrent goroutines on the same account each pay one Redis lock RTT and one DB reread; with it, only the leader pays and the rest share the result. The refreshWindow is part of the key so a long background-refresh window cannot starve a short foreground-refresh window. 2. accountRepository.SetTempUnschedulable Wrap the same path (UPDATE + scheduler outbox enqueue + scheduler cache sync) in a per-process singleflight keyed by id+until+reason. The SQL guard (existing < new) already makes the UPDATE idempotent, but N callers still cost N round-trips and N outbox inserts. With singleflight, an upstream 401 burst that hits the same account collapses to one execution. Tests cover dedup behavior, key separation by account / refresh window, and that the SQL exec count drops from N to <=2 (UPDATE + outbox). --- backend/internal/repository/account_repo.go | 19 +++ .../account_repo_singleflight_test.go | 119 +++++++++++++ backend/internal/service/oauth_refresh_api.go | 54 +++++- .../oauth_refresh_api_singleflight_test.go | 160 ++++++++++++++++++ 4 files changed, 345 insertions(+), 7 deletions(-) create mode 100644 backend/internal/repository/account_repo_singleflight_test.go create mode 100644 backend/internal/service/oauth_refresh_api_singleflight_test.go diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a12..23db17d1 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -29,6 +29,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + "golang.org/x/sync/singleflight" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqljson" @@ -49,6 +50,13 @@ type accountRepository struct { // Used to proactively sync account snapshot to cache when status changes, // ensuring sticky sessions can promptly detect unavailable accounts. schedulerCache service.SchedulerCache + + // tempUnschedSF 在进程内合并对同一账号的并发 SetTempUnschedulable 调用。 + // 上游 401/限流爆发时,N 个 in-flight 请求会同时调用此方法;底层 SQL + // 已经做了 (until < $1) 的 idempotent 保护,不会重复改 row,但 N 次 + // SQL RTT + N 次 outbox enqueue + N 次缓存同步仍然可观。singleflight + // 把这些并发合并成 1 次实际执行,其余 caller 共享同一结果。 + tempUnschedSF singleflight.Group } var schedulerNeutralExtraKeyPrefixes = []string{ @@ -1029,6 +1037,17 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t } func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + // 进程内合并并发调用:key 包含 until 和 reason,确保不同窗口/原因独立去重。 + // until 用毫秒粒度足够:同一爆发窗口内 caller 算的 until 几乎一致; + // 哪怕略有偏差,SQL 的 (existing < new) 条件保证语义安全。 + sfKey := strconv.FormatInt(id, 10) + ":" + strconv.FormatInt(until.UnixMilli(), 10) + ":" + reason + _, err, _ := r.tempUnschedSF.Do(sfKey, func() (interface{}, error) { + return nil, r.setTempUnschedulableOnce(ctx, id, until, reason) + }) + return err +} + +func (r *accountRepository) setTempUnschedulableOnce(ctx context.Context, id int64, until time.Time, reason string) error { _, err := r.sql.ExecContext(ctx, ` UPDATE accounts SET temp_unschedulable_until = $1, diff --git a/backend/internal/repository/account_repo_singleflight_test.go b/backend/internal/repository/account_repo_singleflight_test.go new file mode 100644 index 00000000..5d5d1201 --- /dev/null +++ b/backend/internal/repository/account_repo_singleflight_test.go @@ -0,0 +1,119 @@ +package repository + +import ( + "context" + "database/sql" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// blockingExecutor 是一个最小化的 sqlExecutor 实现,用于精确控制并发时序。 +// ExecContext 会等待 release 信号才返回,便于让多个 goroutine 集中堆积在 +// singleflight 的同一窗口内。 +type blockingExecutor struct { + mu sync.Mutex + execCalls int32 + queryCalls int32 + release chan struct{} + concurrent int32 + maxObserved int32 +} + +func newBlockingExecutor() *blockingExecutor { + return &blockingExecutor{release: make(chan struct{})} +} + +func (e *blockingExecutor) Release() { close(e.release) } + +func (e *blockingExecutor) ExecContext(_ context.Context, _ string, _ ...any) (sql.Result, error) { + atomic.AddInt32(&e.execCalls, 1) + c := atomic.AddInt32(&e.concurrent, 1) + for { + old := atomic.LoadInt32(&e.maxObserved) + if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) { + break + } + } + defer atomic.AddInt32(&e.concurrent, -1) + <-e.release + return driverResult{}, nil +} + +func (e *blockingExecutor) QueryContext(_ context.Context, _ string, _ ...any) (*sql.Rows, error) { + atomic.AddInt32(&e.queryCalls, 1) + return nil, sql.ErrNoRows +} + +// driverResult 是一个零值 sql.Result,用于测试。 +type driverResult struct{} + +func (driverResult) LastInsertId() (int64, error) { return 0, nil } +func (driverResult) RowsAffected() (int64, error) { return 1, nil } + +func TestSetTempUnschedulable_SingleflightDedupesConcurrentCallers(t *testing.T) { + // 同一账号 + 同一 until + 同一 reason 的 N 个并发调用,应只触发一次实际 + // SQL 路径(UPDATE + outbox INSERT = 2 次 ExecContext)。 + exec := newBlockingExecutor() + repo := newAccountRepositoryWithSQL(nil, exec, nil) + + const callers = 30 + until := time.Now().Add(10 * time.Minute) + const reason = "OAuth 401: invalid_grant" + + var wg sync.WaitGroup + wg.Add(callers) + for i := 0; i < callers; i++ { + go func() { + defer wg.Done() + _ = repo.SetTempUnschedulable(context.Background(), 42, until, reason) + }() + } + + // 等首个 ExecContext 进入阻塞,确认 sf 已聚拢调用。 + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) { + time.Sleep(5 * time.Millisecond) + } + require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), + "singleflight should serialize the SQL call to exactly one in-flight execution") + + exec.Release() + wg.Wait() + + // 1 次 UPDATE + 1 次 outbox INSERT = 2 次 exec;其余 29 个 caller 共享结果。 + require.LessOrEqual(t, atomic.LoadInt32(&exec.execCalls), int32(2), + "expected at most 2 ExecContext calls (UPDATE + outbox), got %d", exec.execCalls) + require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), + "no two SQL execs should run concurrently for the same singleflight key") +} + +func TestSetTempUnschedulable_DifferentAccountsRunInParallel(t *testing.T) { + // 不同 account 应分属不同 sf key,能并行写库。 + exec := newBlockingExecutor() + repo := newAccountRepositoryWithSQL(nil, exec, nil) + + until := time.Now().Add(10 * time.Minute) + var wg sync.WaitGroup + for i := int64(1); i <= 3; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + _ = repo.SetTempUnschedulable(context.Background(), i, until, "different reason") + }() + } + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) { + time.Sleep(5 * time.Millisecond) + } + require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), + "different accounts should be able to write in parallel") + + exec.Release() + wg.Wait() +} diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 5dbba638..545db3ea 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -6,6 +6,8 @@ import ( "log/slog" "strconv" "time" + + "golang.org/x/sync/singleflight" ) // OAuthRefreshExecutor 各平台实现的 OAuth 刷新执行器 @@ -29,9 +31,19 @@ type OAuthRefreshResult struct { // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 // 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +// +// 双层去重设计: +// 1. 进程内 singleflight:合并同一 cacheKey 的并发调用(避免 100 个 goroutine +// 都去抢同一把分布式锁、都重读一次 DB)。 +// 2. 跨进程分布式锁(Redis):保证集群范围内只有一个 worker 真正发起 OAuth +// 刷新请求。 +// +// 进程内去重在分布式锁之外做,避免无谓的 Redis RTT;跨进程锁仍是必需的, +// singleflight 解决不了多 pod 同时刷新。 type OAuthRefreshAPI struct { accountRepo AccountRepository tokenCache GeminiTokenCache // 可选,nil = 无锁 + sf singleflight.Group } // NewOAuthRefreshAPI 创建统一刷新 API @@ -42,15 +54,19 @@ func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCac } } -// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token +// RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token。 +// +// 同一 cacheKey 在同一进程内并发调用会被 singleflight 合并;只有"领导者" +// 调用会真正进入下层流程,其余调用共享相同的 *OAuthRefreshResult / error。 // // 流程: -// 1. 获取分布式锁 -// 2. 从 DB 重读最新 account(防止使用过时的 refresh_token) -// 3. 二次检查是否仍需刷新 -// 4. 调用 executor.Refresh() 执行平台特定刷新逻辑 -// 5. 设置 _token_version + 更新 DB -// 6. 释放锁 +// 1. singleflight 合并同 cacheKey 并发调用 +// 2. 获取分布式锁(跨进程) +// 3. 从 DB 重读最新 account(防止使用过时的 refresh_token) +// 4. 二次检查是否仍需刷新 +// 5. 调用 executor.Refresh() 执行平台特定刷新逻辑 +// 6. 设置 _token_version + 更新 DB +// 7. 释放锁 func (api *OAuthRefreshAPI) RefreshIfNeeded( ctx context.Context, account *Account, @@ -59,6 +75,30 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) + // singleflight key 同时区分 cacheKey 和 refreshWindow: + // 不同的刷新窗口(前台短窗口 / 后台长窗口)应当分开判断 NeedsRefresh, + // 否则后台长窗口的"已经在刷"会让前台短窗口误以为已刷新而立刻拿到旧值。 + sfKey := cacheKey + "|" + refreshWindow.String() + + v, err, _ := api.sf.Do(sfKey, func() (interface{}, error) { + return api.refreshOnce(ctx, account, executor, refreshWindow, cacheKey) + }) + if err != nil { + return nil, err + } + result, _ := v.(*OAuthRefreshResult) + return result, nil +} + +// refreshOnce 是 RefreshIfNeeded 的实际工作函数,仅由 singleflight 领导者调用。 +// 拆出来便于直接做锁/重读/刷新的单元测试,并避免在 sf.Do 闭包里管理多重 defer。 +func (api *OAuthRefreshAPI) refreshOnce( + ctx context.Context, + account *Account, + executor OAuthRefreshExecutor, + refreshWindow time.Duration, + cacheKey string, +) (*OAuthRefreshResult, error) { // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { diff --git a/backend/internal/service/oauth_refresh_api_singleflight_test.go b/backend/internal/service/oauth_refresh_api_singleflight_test.go new file mode 100644 index 00000000..3b834943 --- /dev/null +++ b/backend/internal/service/oauth_refresh_api_singleflight_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// blockingExecutor 在 Refresh 中等待 release 信号,便于精确控制并发时序。 +type blockingExecutor struct { + refreshAPIExecutorStub + release chan struct{} + concurrent int32 // 当前正在 Refresh 的 goroutine 数 + maxObserved int32 // 观察到的最大并发数 + calls int32 +} + +func (e *blockingExecutor) Refresh(_ context.Context, _ *Account) (map[string]any, error) { + atomic.AddInt32(&e.calls, 1) + c := atomic.AddInt32(&e.concurrent, 1) + for { + old := atomic.LoadInt32(&e.maxObserved) + if c <= old || atomic.CompareAndSwapInt32(&e.maxObserved, old, c) { + break + } + } + defer atomic.AddInt32(&e.concurrent, -1) + + <-e.release + return e.credentials, e.err +} + +func TestOAuthRefreshAPI_SingleflightDedupesConcurrentCallers(t *testing.T) { + // 同一 cacheKey 同时进入 N 个 goroutine,应只触发 1 次 executor.Refresh。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + + api := NewOAuthRefreshAPI(repo, cache) + + const callers = 20 + results := make([]*OAuthRefreshResult, callers) + errs := make([]error, callers) + var wg sync.WaitGroup + wg.Add(callers) + + for i := 0; i < callers; i++ { + i := i + go func() { + defer wg.Done() + r, err := api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute) + results[i] = r + errs[i] = err + }() + } + + // 等所有 goroutine 都进入 sf 闭包,确保它们集中在同一窗口里抢同一 key。 + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) == 0 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(1), atomic.LoadInt32(&exec.concurrent), "singleflight should serialize callers into one Refresh") + + close(exec.release) + wg.Wait() + + require.Equal(t, int32(1), atomic.LoadInt32(&exec.calls), "executor.Refresh must be called exactly once") + require.Equal(t, int32(1), atomic.LoadInt32(&exec.maxObserved), "no two goroutines should be inside Refresh simultaneously") + + // 所有 caller 应拿到等价结果(不必同实例,singleflight Shared 标志会让多个 caller 共享)。 + for i := 0; i < callers; i++ { + require.NoError(t, errs[i]) + require.NotNil(t, results[i]) + require.True(t, results[i].Refreshed) + } +} + +func TestOAuthRefreshAPI_SingleflightSeparatesDifferentCacheKeys(t *testing.T) { + // 不同账号有不同 cacheKey,应能并行刷新而非互相阻塞。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 1, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + + api := NewOAuthRefreshAPI(repo, cache) + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + platform := "p" + string(rune('a'+i)) + wg.Add(1) + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 1, Platform: platform}, exec, 5*time.Minute) + }() + } + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 3 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(3), atomic.LoadInt32(&exec.maxObserved), "different cacheKeys should run in parallel") + + close(exec.release) + wg.Wait() +} + +func TestOAuthRefreshAPI_SingleflightSeparatesDifferentRefreshWindows(t *testing.T) { + // 同 cacheKey 但不同 refreshWindow(前台短窗口 vs 后台长窗口)应分开判断 + // NeedsRefresh,避免后台长窗口的"已经在刷"让前台短窗口拿到旧值。 + repo := &refreshAPIAccountRepo{account: &Account{ID: 42, Platform: "claude"}} + cache := &refreshAPICacheStub{lockResult: true} + + exec := &blockingExecutor{ + refreshAPIExecutorStub: refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "new"}, + }, + release: make(chan struct{}), + } + api := NewOAuthRefreshAPI(repo, cache) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 5*time.Minute) + }() + go func() { + defer wg.Done() + _, _ = api.RefreshIfNeeded(context.Background(), &Account{ID: 42, Platform: "claude"}, exec, 1*time.Hour) + }() + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&exec.concurrent) < 2 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + require.Equal(t, int32(2), atomic.LoadInt32(&exec.maxObserved), "different refreshWindow should not be merged") + + close(exec.release) + wg.Wait() +} From 95814974deb5f48aa11e1bd6424d5ff6e695df67 Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 01:22:54 +0800 Subject: [PATCH 4/8] feat(rpm): add token bucket smoothing for RPM rate limiting - New RPMTokenBucketService: per-account continuous-refill token buckets (rate = rpm/60 tokens/sec, capacity = rpm). No new dependencies. - GatewayService.AcquireRPMToken() delegates to the bucket service. - Gateway handler inserts RPM token wait BEFORE wrapReleaseOnDone in both Gemini and Anthropic dispatch paths; timeout returns 429 and releases slot. - Config: gateway.rpm_smoothing.enabled (default false) + max_wait_ms (default 5000). - 7 unit tests covering: immediate acquire, zero RPM, timeout, wait+refill, context cancel, account isolation, bucket reset on RPM change. --- backend/cmd/server/wire_gen.go | 3 +- backend/internal/config/config.go | 21 +++ backend/internal/handler/gateway_handler.go | 30 +++++ backend/internal/service/gateway_service.go | 14 +- .../service/rpm_token_bucket_service.go | 120 ++++++++++++++++++ .../service/rpm_token_bucket_service_test.go | 108 ++++++++++++++++ backend/internal/service/wire.go | 1 + 7 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 backend/internal/service/rpm_token_bucket_service.go create mode 100644 backend/internal/service/rpm_token_bucket_service_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ffb53780..43ebc292 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -180,7 +180,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + rpmTokenBucketService := service.NewRPMTokenBucketService() + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, rpmTokenBucketService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 270a0b98..3d5e151f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -481,6 +481,10 @@ type GatewayConfig struct { // UserMessageQueue: 用户消息串行队列配置 // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` + + // RPMSmoothing: RPM 令牌桶平滑配置 + // 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429 + RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"` } type GatewayAntigravityLSWorkerConfig struct { @@ -535,6 +539,23 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string { return "" } +// RPMSmoothingConfig RPM 令牌桶平滑配置 +type RPMSmoothingConfig struct { + // Enabled: 是否启用 RPM 令牌桶平滑(默认 false) + // 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。 + Enabled bool `mapstructure:"enabled"` + // MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429(默认 5000) + MaxWaitMS int `mapstructure:"max_wait_ms"` +} + +// MaxWait returns the configured wait duration, defaulting to 5s. +func (c *RPMSmoothingConfig) MaxWait() time.Duration { + if c.MaxWaitMS <= 0 { + return 5 * time.Second + } + return time.Duration(c.MaxWaitMS) * time.Millisecond +} + // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 type GatewayOpenAIWSConfig struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9..babb9448 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -383,6 +383,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } + // RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒) + // 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。 + if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmCancel() + if rpmErr != nil { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) + return + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) @@ -605,6 +620,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } + // RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒) + // 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。 + if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmCancel() + if rpmErr != nil { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) + return + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 107e5086..dfe3fe34 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -558,7 +558,8 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制) userGroupRateResolver *userGroupRateResolver userGroupRateCache *gocache.Cache userGroupRateSF singleflight.Group @@ -597,6 +598,7 @@ func NewGatewayService( digestStore *DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + rpmTokenBucketSvc *RPMTokenBucketService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -623,6 +625,7 @@ func NewGatewayService( claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, rpmCache: rpmCache, + rpmTokenBucket: rpmTokenBucketSvc, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), settingService: settingService, modelsListCache: gocache.New(modelsListTTL, time.Minute), @@ -2433,6 +2436,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6 return err } +// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed. +// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit. +func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error { + if s.rpmTokenBucket == nil { + return nil + } + return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait) +} + // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) diff --git a/backend/internal/service/rpm_token_bucket_service.go b/backend/internal/service/rpm_token_bucket_service.go new file mode 100644 index 00000000..dfea5798 --- /dev/null +++ b/backend/internal/service/rpm_token_bucket_service.go @@ -0,0 +1,120 @@ +package service + +import ( + "context" + "errors" + "math" + "sync" + "time" +) + +// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait. +var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot") + +// RPMTokenBucketService provides per-account token buckets for RPM smoothing. +// When an account's RPM budget is exhausted, callers can wait up to a configured +// deadline instead of receiving an immediate 429. The bucket refills continuously +// at rpm/60 tokens per second so requests are distributed evenly over time. +type RPMTokenBucketService struct { + buckets sync.Map // map[int64]*rpmEntry +} + +// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService. +func NewRPMTokenBucketService() *RPMTokenBucketService { + return &RPMTokenBucketService{} +} + +type rpmEntry struct { + bucket *tokenBucket + rpm int +} + +// getBucket returns (or creates) the token bucket for accountID. +// If the account's RPM limit has changed since the bucket was created, the bucket is replaced. +func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket { + if v, ok := s.buckets.Load(accountID); ok { + e := v.(*rpmEntry) + if e.rpm == rpm { + return e.bucket + } + // RPM limit changed — replace with a fresh bucket. + fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)} + s.buckets.Store(accountID, fresh) + return fresh.bucket + } + entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)} + actual, _ := s.buckets.LoadOrStore(accountID, entry) + return actual.(*rpmEntry).bucket +} + +// AcquireWithWait attempts to consume one token for the given account. +// It blocks up to maxWait for a token to become available. +// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded, +// or ctx.Err() if the context is cancelled. +// If rpm <= 0 the call returns immediately with nil. +func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error { + if rpm <= 0 { + return nil + } + bucket := s.getBucket(accountID, rpm) + deadline := time.Now().Add(maxWait) + + for { + ok, waitDur := bucket.tryAcquire() + if ok { + return nil + } + + remaining := time.Until(deadline) + if remaining <= 0 || waitDur > remaining { + return ErrRPMWaitTimeout + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitDur): + // token may be available now; retry + } + } +} + +// tokenBucket is a continuous-refill token bucket for a single account. +type tokenBucket struct { + mu sync.Mutex + tokens float64 + maxTokens float64 + rateSec float64 // tokens refilled per second = rpm / 60 + lastFill time.Time +} + +func newTokenBucket(rpm int) *tokenBucket { + max := float64(rpm) + return &tokenBucket{ + tokens: max, + maxTokens: max, + rateSec: float64(rpm) / 60.0, + lastFill: time.Now(), + } +} + +// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token. +// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available. +func (b *tokenBucket) tryAcquire() (bool, time.Duration) { + b.mu.Lock() + defer b.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(b.lastFill).Seconds() + b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec) + b.lastFill = now + + if b.tokens >= 1.0 { + b.tokens -= 1.0 + return true, 0 + } + + deficit := 1.0 - b.tokens + waitSecs := deficit / b.rateSec + return false, time.Duration(waitSecs * float64(time.Second)) +} diff --git a/backend/internal/service/rpm_token_bucket_service_test.go b/backend/internal/service/rpm_token_bucket_service_test.go new file mode 100644 index 00000000..1710c15d --- /dev/null +++ b/backend/internal/service/rpm_token_bucket_service_test.go @@ -0,0 +1,108 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + // Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately. + for i := 0; i < 60; i++ { + err := svc.AcquireWithWait(ctx, 1, 60, 0) + require.NoError(t, err, "call %d should succeed immediately", i+1) + } +} + +func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) { + svc := NewRPMTokenBucketService() + err := svc.AcquireWithWait(context.Background(), 42, 0, 0) + assert.NoError(t, err) +} + +func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // rpm=1 → 1 token/minute. One call drains the bucket. + err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second) + require.NoError(t, err, "first call should succeed") + + // Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms. + start := time.Now() + err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond) + elapsed := time.Since(start) + assert.ErrorIs(t, err, ErrRPMWaitTimeout) + assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block") +} + +func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully. + for i := 0; i < 120; i++ { + require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0)) + } + + // Next call needs to wait ~500ms for the next token. Give it 2s. + start := time.Now() + err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second) + elapsed := time.Since(start) + require.NoError(t, err, "should succeed after waiting for refill") + assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited") + assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long") +} + +func TestRPMTokenBucket_ContextCancellation(t *testing.T) { + svc := NewRPMTokenBucketService() + + // rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining. + // maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms). + // Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first. + for i := 0; i < 120; i++ { + require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0)) + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second) + elapsed := time.Since(start) + assert.ErrorIs(t, err, context.Canceled) + assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly") +} + +func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // Drain account 1 (rpm=1). + require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0)) + + // Account 2 has its own bucket and should succeed immediately. + err := svc.AcquireWithWait(ctx, 2, 1, 0) + assert.NoError(t, err, "different account should have an independent bucket") +} + +func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // Create bucket with rpm=1 and drain it. + require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0)) + // Bucket now empty with rpm=1. + + // Changing RPM to 60 should reset the bucket to full (60 tokens). + err := svc.AcquireWithWait(ctx, 10, 60, 0) + assert.NoError(t, err, "new RPM should cause bucket recreation") +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index abf437d5..2ce138e0 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -424,6 +424,7 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAnnouncementService, NewAdminService, + NewRPMTokenBucketService, NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, From d535688bfdc72ea41e786c39ac72f42b5585b3e0 Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 01:33:05 +0800 Subject: [PATCH 5/8] feat(context): add proactive context compression for long conversations - New context_compressor.go: pure functions operating on raw JSON body (gjson/sjson pattern). approxTokens uses chars/4 heuristic. - compressMessages: removes oldest messages from front, treating consecutive assistant(tool_use)+user(tool_result) pairs as atomic units to prevent orphaned tool_result blocks. - Hooked into Forward() after StripEmptyTextBlocks, gated on account.Credentials[enable_context_compression]. - Config: gateway.context_compression.max_tokens (default 190000). - 8 unit tests covering: approx tokens, no-op when under budget, oldest-message trimming, tool pair preservation, atomic pair removal, body passthrough, body trimming. --- backend/internal/config/config.go | 18 ++ backend/internal/service/account.go | 11 + .../internal/service/context_compressor.go | 151 ++++++++++++++ .../service/context_compressor_test.go | 195 ++++++++++++++++++ backend/internal/service/gateway_service.go | 6 + 5 files changed, 381 insertions(+) create mode 100644 backend/internal/service/context_compressor.go create mode 100644 backend/internal/service/context_compressor_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3d5e151f..4d116313 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -485,6 +485,10 @@ type GatewayConfig struct { // RPMSmoothing: RPM 令牌桶平滑配置 // 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429 RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"` + + // ContextCompression: 主动上下文压缩配置 + // 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息 + ContextCompression ContextCompressionConfig `mapstructure:"context_compression"` } type GatewayAntigravityLSWorkerConfig struct { @@ -556,6 +560,20 @@ func (c *RPMSmoothingConfig) MaxWait() time.Duration { return time.Duration(c.MaxWaitMS) * time.Millisecond } +// ContextCompressionConfig 主动上下文压缩配置 +type ContextCompressionConfig struct { + // MaxTokens: 压缩目标 token 数(chars/4 近似),超出时从最旧消息开始裁剪(默认 190000) + MaxTokens int `mapstructure:"max_tokens"` +} + +// GetMaxTokens returns the configured token budget, defaulting to 190 000. +func (c *ContextCompressionConfig) GetMaxTokens() int { + if c.MaxTokens <= 0 { + return 190_000 + } + return c.MaxTokens +} + // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 type GatewayOpenAIWSConfig struct { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a1449ffd..feb1da37 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -771,6 +771,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool { return false } +// IsContextCompressionEnabled returns true if the account has opted into proactive +// context compression. When enabled, the gateway will trim oldest messages before +// dispatch to keep the estimated token count within the configured budget. +func (a *Account) IsContextCompressionEnabled() bool { + if a.Credentials == nil { + return false + } + enabled, _ := a.Credentials["enable_context_compression"].(bool) + return enabled +} + func (a *Account) IsBedrock() bool { return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock } diff --git a/backend/internal/service/context_compressor.go b/backend/internal/service/context_compressor.go new file mode 100644 index 00000000..a3500cda --- /dev/null +++ b/backend/internal/service/context_compressor.go @@ -0,0 +1,151 @@ +package service + +import ( + "encoding/json" + "math" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation). +// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response. +const defaultContextCompressionMaxTokens = 190_000 + +// approxTokens estimates the token count for a string using the chars/4 heuristic. +func approxTokens(s string) int { + return int(math.Ceil(float64(len(s)) / 4.0)) +} + +// compressMessagesInBody trims the oldest messages from the request body so that the +// estimated token count of the messages array fits within maxTokens. +// Returns the original body unchanged if no compression is needed or if parsing fails. +func compressMessagesInBody(body []byte, maxTokens int) []byte { + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.Exists() || !msgsResult.IsArray() { + return body + } + + // Unmarshal to a typed slice for processing. + var messages []map[string]any + if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil { + return body + } + + compressed, changed := compressMessages(messages, maxTokens) + if !changed { + return body + } + + newMsgs, err := json.Marshal(compressed) + if err != nil { + return body + } + updated, err := sjson.SetRawBytes(body, "messages", newMsgs) + if err != nil { + return body + } + return updated +} + +// compressMessages removes the oldest messages from the front of msgs until the +// estimated total token count is at or below maxTokens. +// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically +// to avoid orphaned tool_result blocks. +// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise. +func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) { + if len(msgs) == 0 { + return msgs, false + } + + // Estimate total tokens. + totalTokens := 0 + for _, m := range msgs { + totalTokens += msgTokens(m) + } + if totalTokens <= maxTokens { + return msgs, false + } + + // Build atomic removal units: tool_use+tool_result consecutive pairs are one unit. + type unit struct { + startIdx int + endIdx int // exclusive + tokens int + } + units := make([]unit, 0, len(msgs)) + i := 0 + for i < len(msgs) { + toks := msgTokens(msgs[i]) + if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) { + toks += msgTokens(msgs[i+1]) + units = append(units, unit{i, i + 2, toks}) + i += 2 + } else { + units = append(units, unit{i, i + 1, toks}) + i++ + } + } + + // Remove units from the front until we are within budget. + // Always keep at least the last unit so we never send an empty messages array. + removeCount := 0 + for removeCount < len(units)-1 && totalTokens > maxTokens { + totalTokens -= units[removeCount].tokens + removeCount++ + } + if removeCount == 0 { + return msgs, false + } + + cutIdx := units[removeCount].startIdx + return msgs[cutIdx:], true +} + +// msgTokens estimates token count for a single message using the chars/4 heuristic. +func msgTokens(msg map[string]any) int { + b, err := json.Marshal(msg) + if err != nil { + return 0 + } + return approxTokens(string(b)) +} + +// isAssistantWithToolUse returns true if msg is an assistant message whose content +// contains at least one block with "type": "tool_use". +func isAssistantWithToolUse(msg map[string]any) bool { + role, _ := msg["role"].(string) + if role != "assistant" { + return false + } + return contentContainsType(msg["content"], "tool_use") +} + +// isUserWithToolResult returns true if msg is a user message whose content +// contains at least one block with "type": "tool_result". +func isUserWithToolResult(msg map[string]any) bool { + role, _ := msg["role"].(string) + if role != "user" { + return false + } + return contentContainsType(msg["content"], "tool_result") +} + +// contentContainsType returns true if content (a []any of blocks) contains a block +// whose "type" field equals blockType. +func contentContainsType(content any, blockType string) bool { + blocks, ok := content.([]any) + if !ok { + return false + } + for _, b := range blocks { + block, ok := b.(map[string]any) + if !ok { + continue + } + if t, _ := block["type"].(string); t == blockType { + return true + } + } + return false +} diff --git a/backend/internal/service/context_compressor_test.go b/backend/internal/service/context_compressor_test.go new file mode 100644 index 00000000..bad3983d --- /dev/null +++ b/backend/internal/service/context_compressor_test.go @@ -0,0 +1,195 @@ +package service + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helpers + +func makeMsg(role, text string) map[string]any { + return map[string]any{ + "role": role, + "content": text, + } +} + +func makeToolUseMsg(id string) map[string]any { + return map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "tool_use", + "id": id, + "name": "search", + "input": map[string]any{}, + }, + }, + } +} + +func makeToolResultMsg(toolUseID string) map[string]any { + return map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": toolUseID, + "content": "result text", + }, + }, + } +} + +func toAnySlice(msgs []map[string]any) []any { + out := make([]any, len(msgs)) + for i, m := range msgs { + out[i] = m + } + return out +} + +func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte { + t.Helper() + b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"}) + require.NoError(t, err) + return b +} + +// tests + +func TestApproxTokens(t *testing.T) { + assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token + assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens + assert.Equal(t, 0, approxTokens("")) +} + +func TestCompressMessages_NoCompressionNeeded(t *testing.T) { + msgs := []map[string]any{ + makeMsg("user", "hi"), + makeMsg("assistant", "hello"), + } + result, changed := compressMessages(msgs, 100_000) + assert.False(t, changed) + assert.Len(t, result, 2) +} + +func TestCompressMessages_TrimsOldestMessages(t *testing.T) { + // 10 messages, each large enough to be over a tight budget when combined. + msgs := make([]map[string]any, 10) + for i := range msgs { + role := "user" + if i%2 == 1 { + role = "assistant" + } + msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i)) + } + + // Force compression by using a very small token budget. + result, changed := compressMessages(msgs, 1) + assert.True(t, changed) + // Must keep at least one message (the last). + assert.GreaterOrEqual(t, len(result), 1) + // The remaining messages should be from the tail (newest). + lastOrig := msgs[len(msgs)-1]["content"] + lastResult := result[len(result)-1]["content"] + assert.Equal(t, lastOrig, lastResult) +} + +func TestCompressMessages_PreservesToolUsePairs(t *testing.T) { + // Messages: user → assistant+tool_use → user+tool_result → assistant + msgs := []map[string]any{ + makeMsg("user", "start"), + makeToolUseMsg("tool-1"), + makeToolResultMsg("tool-1"), + makeMsg("assistant", "done"), + } + + // Budget that forces removal of the first non-paired message but keeps the tool pair. + // Estimate total tokens and set budget to force removing only "start" but not the pair. + total := 0 + for _, m := range msgs { + total += msgTokens(m) + } + // Budget: remove "start" but keep tool pair + "done". + startTokens := msgTokens(msgs[0]) + budget := total - startTokens + + result, changed := compressMessages(msgs, budget) + assert.True(t, changed) + + // tool_use and tool_result should both be present or both absent. + hasToolUse := false + hasToolResult := false + for _, m := range result { + if isAssistantWithToolUse(m) { + hasToolUse = true + } + if isUserWithToolResult(m) { + hasToolResult = true + } + } + assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all") +} + +func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) { + // Budget forces removal of the tool pair. + msgs := []map[string]any{ + makeMsg("user", "start"), + makeToolUseMsg("tool-1"), + makeToolResultMsg("tool-1"), + makeMsg("assistant", "final answer after tool use"), + } + + // Budget: only keep the last "assistant" message. + lastTokens := msgTokens(msgs[len(msgs)-1]) + + result, changed := compressMessages(msgs, lastTokens) + assert.True(t, changed) + + // Neither tool_use nor tool_result should remain. + for _, m := range result { + assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair") + assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair") + } +} + +func TestCompressMessagesInBody_NoMessages(t *testing.T) { + body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`) + result := compressMessagesInBody(body, 1) + assert.Equal(t, body, result, "body without messages should be unchanged") +} + +func TestCompressMessagesInBody_UnderBudget(t *testing.T) { + msgs := []map[string]any{makeMsg("user", "hi")} + body := bodyWithMessages(t, msgs) + result := compressMessagesInBody(body, 100_000) + assert.Equal(t, body, result, "body under budget should be unchanged") +} + +func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) { + msgs := make([]map[string]any, 20) + for i := range msgs { + role := "user" + if i%2 == 1 { + role = "assistant" + } + msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i)) + } + body := bodyWithMessages(t, msgs) + + // Force significant compression. + result := compressMessagesInBody(body, 50) + assert.Less(t, len(result), len(body), "compressed body should be smaller") + + // Resulting body should still be valid JSON with a messages array. + var parsed map[string]any + require.NoError(t, json.Unmarshal(result, &parsed)) + resultMsgs, ok := parsed["messages"].([]any) + require.True(t, ok) + assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty") +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index dfe3fe34..23a7ccbc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4182,6 +4182,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. body = StripEmptyTextBlocks(body) + // 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。 + if account.IsContextCompressionEnabled() { + maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens() + body = compressMessagesInBody(body, maxTok) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 setOpsUpstreamRequestBody(c, body) From d1e2d39c263da86815e3bd193bb77fdd05f455db Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 01:48:15 +0800 Subject: [PATCH 6/8] feat(viewer): add real-time request stream WebSocket endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds GET /api/v1/admin/ops/ws/requests — a fan-out WebSocket that pushes per-request metadata (method, path, model, account_id, status, latency_ms) to all connected admin clients the moment each gateway dispatch completes. - service/request_event_bus.go: lock-free pub/sub with non-blocking drop when per-subscriber buffer (64 slots) is full; nil-safe Publish - service/request_event_bus_test.go: 6 tests (basic, fanout, drop, nil, close) - GatewayHandler: records reqStartTime at entry; defer emits RequestEvent on every return; sets status success/error/rate_limited in both Gemini and Anthropic dispatch paths - OpsHandler: accepts *RequestEventBus; wires it to RequestStreamWSHandler - ops_ws_requests_handler.go: subscribes to bus, pushes JSON per event, reuses existing upgrader/conn-limit/ping-pong infrastructure - Route: ws.GET("/requests", ...) alongside existing /ws/qps - wire_gen.go: requestEventBus shared between OpsHandler and GatewayHandler --- backend/cmd/server/wire_gen.go | 5 +- backend/internal/handler/admin/ops_handler.go | 7 +- .../admin/ops_runtime_logging_handler_test.go | 6 +- .../admin/ops_system_log_handler_test.go | 28 +-- .../handler/admin/ops_ws_requests_handler.go | 198 ++++++++++++++++++ backend/internal/handler/gateway_handler.go | 35 ++++ backend/internal/server/routes/admin.go | 3 +- backend/internal/service/request_event_bus.go | 75 +++++++ .../service/request_event_bus_test.go | 100 +++++++++ backend/internal/service/wire.go | 1 + 10 files changed, 435 insertions(+), 23 deletions(-) create mode 100644 backend/internal/handler/admin/ops_ws_requests_handler.go create mode 100644 backend/internal/service/request_event_bus.go create mode 100644 backend/internal/service/request_event_bus_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 43ebc292..6a92ceb6 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -193,7 +193,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) - opsHandler := admin.NewOpsHandler(opsService) + requestEventBus := service.NewRequestEventBus() + opsHandler := admin.NewOpsHandler(opsService, requestEventBus) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) @@ -223,7 +224,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) diff --git a/backend/internal/handler/admin/ops_handler.go b/backend/internal/handler/admin/ops_handler.go index 44accc8f..d9c49250 100644 --- a/backend/internal/handler/admin/ops_handler.go +++ b/backend/internal/handler/admin/ops_handler.go @@ -16,7 +16,8 @@ import ( ) type OpsHandler struct { - opsService *service.OpsService + opsService *service.OpsService + requestEventBus *service.RequestEventBus } // GetErrorLogByID returns ops error log detail. @@ -70,8 +71,8 @@ func parseOpsViewParam(c *gin.Context) string { } } -func NewOpsHandler(opsService *service.OpsService) *OpsHandler { - return &OpsHandler{opsService: opsService} +func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler { + return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus} } // GetErrorLogs lists ops error logs. diff --git a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go index 0e84b4f9..0eede09b 100644 --- a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go +++ b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go @@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService { } func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t)) + h := NewOpsHandler(newRuntimeOpsService(t), nil) r := newOpsRuntimeRouter(h, false) w := httptest.NewRecorder() @@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { } func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t)) + h := NewOpsHandler(newRuntimeOpsService(t), nil) r := newOpsRuntimeRouter(h, false) body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` @@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { } func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t)) + h := NewOpsHandler(newRuntimeOpsService(t), nil) r := newOpsRuntimeRouter(h, true) payload := map[string]any{ diff --git a/backend/internal/handler/admin/ops_system_log_handler_test.go b/backend/internal/handler/admin/ops_system_log_handler_test.go index 7528acd8..a030320c 100644 --- a/backend/internal/handler/admin/ops_system_log_handler_test.go +++ b/backend/internal/handler/admin/ops_system_log_handler_test.go @@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine { } func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { - h := NewOpsHandler(nil) + h := NewOpsHandler(nil, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) { svc := service.NewOpsService(nil, nil, &config.Config{ Ops: config.OpsConfig{Enabled: false}, }, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) { func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, true) w := httptest.NewRecorder() @@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, true) w := httptest.NewRecorder() @@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, true) w := httptest.NewRecorder() @@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) { svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, true) w := httptest.NewRecorder() @@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) { svc := service.NewOpsService(nil, nil, &config.Config{ Ops: config.OpsConfig{Enabled: false}, }, nil, nil, nil, nil, nil, nil, nil, nil) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, true) w := httptest.NewRecorder() @@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) { func TestOpsSystemLogHandler_Health(t *testing.T) { sink := service.NewOpsSystemLogSink(nil) svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) - h := NewOpsHandler(svc) + h := NewOpsHandler(svc, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) { } func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) { - h := NewOpsHandler(nil) + h := NewOpsHandler(nil, nil) r := newOpsSystemLogTestRouter(h, false) w := httptest.NewRecorder() @@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T svc := service.NewOpsService(nil, nil, &config.Config{ Ops: config.OpsConfig{Enabled: false}, }, nil, nil, nil, nil, nil, nil, nil, nil) - h = NewOpsHandler(svc) + h = NewOpsHandler(svc, nil) r = newOpsSystemLogTestRouter(h, false) w = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/logs/health", nil) diff --git a/backend/internal/handler/admin/ops_ws_requests_handler.go b/backend/internal/handler/admin/ops_ws_requests_handler.go new file mode 100644 index 00000000..323018c6 --- /dev/null +++ b/backend/internal/handler/admin/ops_ws_requests_handler.go @@ -0,0 +1,198 @@ +package admin + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +type requestStreamWSMessage struct { + Type string `json:"type"` + Data service.RequestEvent `json:"data"` +} + +// RequestStreamWSHandler streams real-time request events to WebSocket clients. +// GET /api/v1/admin/ops/ws/requests +// +// Each connected client receives a JSON message per gateway dispatch: +// +// {"type":"request_event","data":{"timestamp":...,"method":"POST","path":"/v1/messages", +// "model":"claude-3-5-sonnet-20241022","account_id":42,"status":"success","latency_ms":1230}} +func (h *OpsHandler) RequestStreamWSHandler(c *gin.Context) { + clientIP := requestClientIP(c.Request) + + if h == nil || h.opsService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"}) + return + } + if h.requestEventBus == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "request event bus not initialized"}) + return + } + + if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"}) + return + } + closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled") + return + } + + if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) + return + } + defer func() { + if wsConnCount.Add(-1) == 0 { + scheduleQPSWSIdleStop() + } + }() + + if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { + if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] per-ip limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) + return + } + defer releaseOpsWSIPSlot(clientIP) + } + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] upgrade failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + handleRequestStreamWebSocket(c.Request.Context(), conn, h.requestEventBus) +} + +func handleRequestStreamWebSocket(parentCtx context.Context, conn *websocket.Conn, bus *service.RequestEventBus) { + if conn == nil || bus == nil { + return + } + + ctx, cancel := context.WithCancel(parentCtx) + defer cancel() + + subID, eventCh := bus.Subscribe() + defer bus.Unsubscribe(subID) + + var closeOnce sync.Once + closeConn := func() { + closeOnce.Do(func() { _ = conn.Close() }) + } + + closeFrameCh := make(chan []byte, 1) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + + conn.SetReadLimit(qpsWSMaxReadBytes) + if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] set read deadline failed: %v", err) + return + } + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)) + }) + conn.SetCloseHandler(func(code int, text string) error { + select { + case closeFrameCh <- websocket.FormatCloseMessage(code, text): + default: + } + cancel() + return nil + }) + + for { + _, _, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] read failed: %v", err) + } + return + } + } + }() + + pingTicker := time.NewTicker(qpsWSPingInterval) + defer pingTicker.Stop() + + writeWithTimeout := func(messageType int, data []byte) error { + if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil { + return err + } + return conn.WriteMessage(messageType, data) + } + + sendClose := func(closeFrame []byte) { + if closeFrame == nil { + closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + } + _ = writeWithTimeout(websocket.CloseMessage, closeFrame) + } + + for { + select { + case evt, ok := <-eventCh: + if !ok { + // channel closed by Unsubscribe + sendClose(nil) + closeConn() + wg.Wait() + return + } + msg, err := json.Marshal(requestStreamWSMessage{Type: "request_event", Data: evt}) + if err != nil { + continue + } + if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] write failed: %v", err) + cancel() + closeConn() + wg.Wait() + return + } + + case <-pingTicker.C: + if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { + logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] ping failed: %v", err) + cancel() + closeConn() + wg.Wait() + return + } + + case closeFrame := <-closeFrameCh: + sendClose(closeFrame) + closeConn() + wg.Wait() + return + + case <-ctx.Done(): + var closeFrame []byte + select { + case closeFrame = <-closeFrameCh: + default: + } + sendClose(closeFrame) + closeConn() + wg.Wait() + return + } + } +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index babb9448..9518cc44 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -47,6 +47,7 @@ type GatewayHandler struct { errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper userMsgQueueHelper *UserMsgQueueHelper + requestEventBus *service.RequestEventBus maxAccountSwitches int maxAccountSwitchesGemini int cfg *config.Config @@ -68,6 +69,7 @@ func NewGatewayHandler( userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, settingService *service.SettingService, + requestEventBus *service.RequestEventBus, ) *GatewayHandler { pingInterval := time.Duration(0) maxAccountSwitches := 10 @@ -100,6 +102,7 @@ func NewGatewayHandler( errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), userMsgQueueHelper: umqHelper, + requestEventBus: requestEventBus, maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, cfg: cfg, @@ -110,6 +113,7 @@ func NewGatewayHandler( // Messages handles Claude API compatible messages endpoint // POST /v1/messages func (h *GatewayHandler) Messages(c *gin.Context) { + reqStartTime := time.Now() // 从context获取apiKey和user(ApiKeyAuth中间件已设置) apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -158,6 +162,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + // 实时请求查看器:记录每次请求的结果(账号、模型、状态、延迟) + var ( + reqEventAccountID int64 + reqEventStatus = "error" + ) + defer func() { + if h.requestEventBus != nil { + h.requestEventBus.Publish(service.RequestEvent{ + Timestamp: reqStartTime, + Method: c.Request.Method, + Path: c.FullPath(), + Model: reqModel, + AccountID: reqEventAccountID, + Status: reqEventStatus, + LatencyMS: time.Since(reqStartTime).Milliseconds(), + }) + } + }() + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { @@ -393,6 +416,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + reqEventAccountID = account.ID + reqEventStatus = "rate_limited" h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) return } @@ -458,6 +483,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // 实时请求查看器:标记 Gemini 路径成功 + reqEventAccountID = account.ID + reqEventStatus = "success" + // RPM 计数递增(Forward 成功后) // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 @@ -630,6 +659,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + reqEventAccountID = account.ID + reqEventStatus = "rate_limited" h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) return } @@ -805,6 +836,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // 实时请求查看器:标记 Anthropic 路径成功 + reqEventAccountID = account.ID + reqEventStatus = "success" + // RPM 计数递增(Forward 成功后) // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e04dae85..0dc698a9 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -141,10 +141,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds) } - // WebSocket realtime (QPS/TPS) + // WebSocket realtime (QPS/TPS and request stream) ws := ops.Group("/ws") { ws.GET("/qps", h.Admin.Ops.QPSWSHandler) + ws.GET("/requests", h.Admin.Ops.RequestStreamWSHandler) } // Error logs (legacy) diff --git a/backend/internal/service/request_event_bus.go b/backend/internal/service/request_event_bus.go new file mode 100644 index 00000000..0664f71f --- /dev/null +++ b/backend/internal/service/request_event_bus.go @@ -0,0 +1,75 @@ +package service + +import ( + "sync" + "sync/atomic" + "time" +) + +const requestEventBufSize = 64 + +// RequestEvent is published for every gateway dispatch completion. +type RequestEvent struct { + Timestamp time.Time `json:"timestamp"` + Method string `json:"method"` + Path string `json:"path"` + Model string `json:"model"` + AccountID int64 `json:"account_id"` + // Status is "success", "error", or "rate_limited". + Status string `json:"status"` + LatencyMS int64 `json:"latency_ms"` +} + +// RequestEventBus is a fan-out hub for real-time request events. +// Publishers call Publish; subscribers call Subscribe/Unsubscribe. +// Each subscriber gets its own buffered channel. If the buffer is full +// the event is dropped for that subscriber (non-blocking publish). +type RequestEventBus struct { + mu sync.RWMutex + subscribers map[uint64]chan RequestEvent + nextID atomic.Uint64 +} + +func NewRequestEventBus() *RequestEventBus { + return &RequestEventBus{ + subscribers: make(map[uint64]chan RequestEvent), + } +} + +// Subscribe registers a new subscriber and returns its ID and a receive-only channel. +func (b *RequestEventBus) Subscribe() (uint64, <-chan RequestEvent) { + id := b.nextID.Add(1) + ch := make(chan RequestEvent, requestEventBufSize) + b.mu.Lock() + b.subscribers[id] = ch + b.mu.Unlock() + return id, ch +} + +// Unsubscribe removes a subscriber and closes its channel. +func (b *RequestEventBus) Unsubscribe(id uint64) { + b.mu.Lock() + ch, ok := b.subscribers[id] + if ok { + delete(b.subscribers, id) + } + b.mu.Unlock() + if ok { + close(ch) + } +} + +// Publish sends an event to all current subscribers without blocking. +func (b *RequestEventBus) Publish(e RequestEvent) { + if b == nil { + return + } + b.mu.RLock() + defer b.mu.RUnlock() + for _, ch := range b.subscribers { + select { + case ch <- e: + default: + } + } +} diff --git a/backend/internal/service/request_event_bus_test.go b/backend/internal/service/request_event_bus_test.go new file mode 100644 index 00000000..9c26912e --- /dev/null +++ b/backend/internal/service/request_event_bus_test.go @@ -0,0 +1,100 @@ +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestEventBus_PublishToSubscriber(t *testing.T) { + bus := NewRequestEventBus() + + id, ch := bus.Subscribe() + defer bus.Unsubscribe(id) + + evt := RequestEvent{Model: "claude-3", Status: "success", LatencyMS: 100} + bus.Publish(evt) + + select { + case got := <-ch: + assert.Equal(t, evt, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") + } +} + +func TestRequestEventBus_MultipleSubscribers(t *testing.T) { + bus := NewRequestEventBus() + + id1, ch1 := bus.Subscribe() + id2, ch2 := bus.Subscribe() + defer bus.Unsubscribe(id1) + defer bus.Unsubscribe(id2) + + evt := RequestEvent{Model: "claude-3", Status: "error"} + bus.Publish(evt) + + for _, ch := range []<-chan RequestEvent{ch1, ch2} { + select { + case got := <-ch: + assert.Equal(t, evt, got) + case <-time.After(time.Second): + t.Fatal("timed out waiting for event on one subscriber") + } + } +} + +func TestRequestEventBus_UnsubscribeClosesChannel(t *testing.T) { + bus := NewRequestEventBus() + id, ch := bus.Subscribe() + + bus.Unsubscribe(id) + + // Channel should be closed. + _, ok := <-ch + assert.False(t, ok, "channel should be closed after Unsubscribe") +} + +func TestRequestEventBus_UnsubscribedMissesEvents(t *testing.T) { + bus := NewRequestEventBus() + id, _ := bus.Subscribe() + bus.Unsubscribe(id) + + // Publish after unsubscribe should not panic. + require.NotPanics(t, func() { + bus.Publish(RequestEvent{Model: "test"}) + }) +} + +func TestRequestEventBus_DropWhenFull(t *testing.T) { + bus := NewRequestEventBus() + id, ch := bus.Subscribe() + defer bus.Unsubscribe(id) + + // Fill the buffer then publish one more — should drop, not block. + evt := RequestEvent{Model: "model", Status: "success"} + for i := 0; i < requestEventBufSize; i++ { + bus.Publish(evt) + } + // This publish should return immediately (dropped). + done := make(chan struct{}) + go func() { + bus.Publish(evt) + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Publish blocked when buffer was full") + } + assert.Len(t, ch, requestEventBufSize) +} + +func TestRequestEventBus_NilSafePublish(t *testing.T) { + var bus *RequestEventBus + require.NotPanics(t, func() { + bus.Publish(RequestEvent{Model: "test"}) + }) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 2ce138e0..b3add8d0 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -425,6 +425,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementService, NewAdminService, NewRPMTokenBucketService, + NewRequestEventBus, NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, From a2ab67f8c771ead234d8194c1e7fd64834fb9a51 Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 03:13:30 +0800 Subject: [PATCH 7/8] feat(scheduler): add P2C + quota-aware scheduling for OpenAI accounts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GetQuotaRemainingFraction() to Account: returns [0,1] fraction of remaining quota; 1.0 when no limit is configured (unlimited accounts) - Add Quota float64 weight field to GatewayOpenAIWSSchedulerScoreWeights and EnableP2CScheduling bool to GatewayOpenAIWSConfig (both default off) - Extend selectByLoadBalance scoring with quota factor (gated by Quota>0) - Add selectByPowerOfTwo(): O(1) P2C selection — samples 2 random candidates, tries the better-scored one first then the other, falls back to wait plan; activated when EnableP2CScheduling=true - Add openAIWSP2CEnabled() helper on OpenAIGatewayService - Add 6 tests covering quota fraction edge cases, P2C toggle, weight defaults, single-candidate P2C, two-candidate P2C selection, and quota score ordering --- backend/internal/config/config.go | 7 +- backend/internal/service/account.go | 18 +++ .../service/openai_account_scheduler.go | 103 +++++++++++++- .../service/openai_account_scheduler_test.go | 126 ++++++++++++++++++ 4 files changed, 252 insertions(+), 2 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 4d116313..276d76b9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -657,6 +657,8 @@ type GatewayOpenAIWSConfig struct { // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + // EnableP2CScheduling: 启用 Power-of-Two-Choices 调度(默认 false,使用 top-K 加权随机) + EnableP2CScheduling bool `mapstructure:"enable_p2c_scheduling"` SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` } @@ -667,6 +669,8 @@ type GatewayOpenAIWSSchedulerScoreWeights struct { Queue float64 `mapstructure:"queue"` ErrorRate float64 `mapstructure:"error_rate"` TTFT float64 `mapstructure:"ttft"` + // Quota: 剩余配额比例权重(0 表示不参与打分) + Quota float64 `mapstructure:"quota"` } // GatewayUsageRecordConfig 使用量记录异步队列配置 @@ -2197,7 +2201,8 @@ func (c *Config) Validate() error { c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || - c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Quota < 0 { return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") } weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index feb1da37..9f875f23 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1305,6 +1305,24 @@ func (a *Account) GetQuotaUsed() float64 { return a.getExtraFloat64("quota_used") } +// GetQuotaRemainingFraction returns the fraction of total quota remaining in [0,1]. +// Returns 1.0 when no quota limit is set (limit == 0 means unlimited). +func (a *Account) GetQuotaRemainingFraction() float64 { + limit := a.GetQuotaLimit() + if limit <= 0 { + return 1.0 + } + used := a.GetQuotaUsed() + remaining := (limit - used) / limit + if remaining < 0 { + return 0 + } + if remaining > 1 { + return 1 + } + return remaining +} + // GetQuotaDailyLimit 获取日额度限制(美元),0 表示未启用 func (a *Account) GetQuotaDailyLimit() float64 { return a.getExtraFloat64("quota_daily_limit") diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 37e7ed2c..0b80a80e 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -672,12 +672,18 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) } + quotaFactor := item.account.GetQuotaRemainingFraction() item.score = weights.Priority*priorityFactor + weights.Load*loadFactor + weights.Queue*queueFactor + weights.ErrorRate*errorFactor + - weights.TTFT*ttftFactor + weights.TTFT*ttftFactor + + weights.Quota*quotaFactor + } + + if s.service.openAIWSP2CEnabled() { + return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew) } topK := s.service.openAIWSLBTopK() @@ -888,6 +894,7 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + Quota: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota, } } return GatewayOpenAIWSSchedulerScoreWeightsView{ @@ -896,15 +903,21 @@ func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedul Queue: 0.7, ErrorRate: 0.8, TTFT: 0.5, + Quota: 0.0, } } +func (s *OpenAIGatewayService) openAIWSP2CEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EnableP2CScheduling +} + type GatewayOpenAIWSSchedulerScoreWeightsView struct { Priority float64 Load float64 Queue float64 ErrorRate float64 TTFT float64 + Quota float64 } func clamp01(value float64) float64 { @@ -918,6 +931,94 @@ func clamp01(value float64) float64 { } } +// selectByPowerOfTwo implements Power-of-Two-Choices (P2C): sample 2 random +// candidates and attempt the better-scored one first, then the other. +// This gives O(1) selection with load distribution comparable to top-K when N is large. +func (s *defaultOpenAIAccountScheduler) selectByPowerOfTwo( + ctx context.Context, + req OpenAIAccountScheduleRequest, + candidates []openAIAccountCandidateScore, + loadSkew float64, +) (*AccountSelectionResult, int, int, float64, error) { + n := len(candidates) + if n == 0 { + return nil, 0, 0, loadSkew, ErrNoAvailableAccounts + } + + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + + // Pick two distinct random indices. + idxA := int(rng.nextUint64() % uint64(n)) + idxB := idxA + if n > 1 { + for idxB == idxA { + idxB = int(rng.nextUint64() % uint64(n)) + } + } + + // Order: better candidate first. + first, second := candidates[idxA], candidates[idxB] + if isOpenAIAccountCandidateBetter(second, first) { + first, second = second, first + } + + tryAcquire := func(c openAIAccountCandidateScore) (*AccountSelectionResult, bool, error) { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + return nil, false, nil + } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + return nil, false, nil + } + result, err := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err != nil { + return nil, false, err + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true, nil + } + return nil, false, nil + } + + for _, c := range []openAIAccountCandidateScore{first, second} { + result, ok, err := tryAcquire(c) + if err != nil { + return nil, n, 2, loadSkew, err + } + if ok { + return result, n, 2, loadSkew, nil + } + } + + // Both slots busy — return wait plan on the better candidate. + cfg := s.service.schedulingConfig() + for _, c := range []openAIAccountCandidateScore{first, second} { + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, c.account, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } + return &AccountSelectionResult{ + Account: fresh, + WaitPlan: &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, n, 2, loadSkew, nil + } + + return nil, n, 2, loadSkew, ErrNoAvailableAccounts +} + func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { if count <= 1 { return 0 diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed..35d26e2b 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -966,3 +966,129 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * func int64PtrForTest(v int64) *int64 { return &v } + +func TestAccount_GetQuotaRemainingFraction(t *testing.T) { + // No limit configured → always 1.0 (unlimited) + noLimit := &Account{} + require.Equal(t, 1.0, noLimit.GetQuotaRemainingFraction()) + + // 50% used + half := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 50.0}} + require.InDelta(t, 0.5, half.GetQuotaRemainingFraction(), 1e-9) + + // Fully exhausted + full := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 100.0}} + require.Equal(t, 0.0, full.GetQuotaRemainingFraction()) + + // Over limit → clamp to 0 + over := &Account{Extra: map[string]any{"quota_limit": 100.0, "quota_used": 150.0}} + require.Equal(t, 0.0, over.GetQuotaRemainingFraction()) + + // Fresh (0 used) + fresh := &Account{Extra: map[string]any{"quota_limit": 200.0, "quota_used": 0.0}} + require.Equal(t, 1.0, fresh.GetQuotaRemainingFraction()) +} + +func TestOpenAIGatewayService_P2CEnabled(t *testing.T) { + require.False(t, (*OpenAIGatewayService)(nil).openAIWSP2CEnabled()) + require.False(t, (&OpenAIGatewayService{}).openAIWSP2CEnabled()) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.EnableP2CScheduling = false + require.False(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled()) + + cfg.Gateway.OpenAIWS.EnableP2CScheduling = true + require.True(t, (&OpenAIGatewayService{cfg: cfg}).openAIWSP2CEnabled()) +} + +func TestOpenAIGatewayService_SchedulerWeights_QuotaField(t *testing.T) { + // Default weights: Quota is 0 (disabled by default) + svc := &OpenAIGatewayService{} + weights := svc.openAIWSSchedulerWeights() + require.Equal(t, 0.0, weights.Quota) + + // Config-driven quota weight + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Quota = 0.4 + svcWithCfg := &OpenAIGatewayService{cfg: cfg} + require.Equal(t, 0.4, svcWithCfg.openAIWSSchedulerWeights().Quota) +} + +func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_SingleCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(99001) + account := &Account{ID: 71001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{account}, + accountsByID: map[int64]*Account{71001: account}, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.EnableP2CScheduling = true + cfg.Gateway.OpenAIWS.LBTopK = 5 + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, + cfg: cfg, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.Equal(t, int64(71001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestDefaultOpenAIAccountScheduler_SelectByPowerOfTwo_PicksBetterCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(99002) + // Account A has low priority (better), B has high priority (worse). + // With P2C enabled and a deterministic seed, we should always get a valid selection. + accountA := &Account{ID: 72001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0} + accountB := &Account{ID: 72002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 10} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{accountA, accountB}, + accountsByID: map[int64]*Account{72001: accountA, 72002: accountB}, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.EnableP2CScheduling = true + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{*accountA, *accountB}}, + cfg: cfg, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-4o", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // Either account is valid; just verify we got a schedulable one. + require.True(t, selection.Account.ID == 72001 || selection.Account.ID == 72002) +} + +func TestDefaultOpenAIAccountScheduler_QuotaFactorInfluencesScore(t *testing.T) { + // Verify that quota weight affects scoring by checking GetQuotaRemainingFraction is used. + // Account with high remaining quota should score higher when quota weight > 0. + highQuota := &Account{ + ID: 73001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + Extra: map[string]any{"quota_limit": 100.0, "quota_used": 10.0}, // 90% remaining + } + lowQuota := &Account{ + ID: 73002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + Extra: map[string]any{"quota_limit": 100.0, "quota_used": 90.0}, // 10% remaining + } + + require.InDelta(t, 0.9, highQuota.GetQuotaRemainingFraction(), 1e-9) + require.InDelta(t, 0.1, lowQuota.GetQuotaRemainingFraction(), 1e-9) + + // With quota weight = 1.0 and all other weights = 0, high-quota account should win. + // We verify the score ordering directly using isOpenAIAccountCandidateBetter. + highScore := openAIAccountCandidateScore{account: highQuota, score: 0.9} + lowScore := openAIAccountCandidateScore{account: lowQuota, score: 0.1} + require.True(t, isOpenAIAccountCandidateBetter(highScore, lowScore)) + require.False(t, isOpenAIAccountCandidateBetter(lowScore, highScore)) +} From 5123d92b44663052687feddac52be29f5b7bf905 Mon Sep 17 00:00:00 2001 From: win Date: Wed, 29 Apr 2026 03:23:39 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat(scheduling):=20add=20cross-tier=20fall?= =?UTF-8?q?back=20chain=20(subscription=20=E2=86=92=20API=20Key=20?= =?UTF-8?q?=E2=86=92=20Bedrock)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an opt-in tier-based fallback scheduling path for Anthropic accounts: - accountTierLevel(): derives tier from account type without DB migration (tier-0=OAuth/SetupToken, tier-1=APIKey, tier-2=Bedrock) - enableTierFallbackChain(): new config flag gateway.scheduling.enable_tier_fallback_chain (default false) - selectAccountWithTierFallback(): loads all Anthropic accounts, groups by tier, honors sticky sessions, applies all existing schedulability guards, then tries tiers 0→1→2 in order via tryAcquireByLegacyOrder - Wired into SelectAccountForModelWithExclusions: Anthropic platform + tier fallback enabled → calls new path instead of mixed scheduling - Fix pre-existing unit-test build break: NewGatewayService now requires *RPMTokenBucketService (added in Task #5); add missing nil param - 7 tests: tier mapping, config toggle, subscription preference, APIKey fallback, exclusion handling, empty-pool error, Bedrock last resort --- backend/internal/config/config.go | 4 + .../service/gateway_record_usage_test.go | 1 + backend/internal/service/gateway_service.go | 3 + .../internal/service/gateway_tier_fallback.go | 133 +++++++++++++++++ .../service/gateway_tier_fallback_test.go | 138 ++++++++++++++++++ 5 files changed, 279 insertions(+) create mode 100644 backend/internal/service/gateway_tier_fallback.go create mode 100644 backend/internal/service/gateway_tier_fallback_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 276d76b9..17f45b74 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -810,6 +810,10 @@ type GatewaySchedulingConfig struct { // 全量重建周期配置 // 全量重建周期(秒),0 表示禁用 FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"` + + // EnableTierFallbackChain: 启用跨档降级链(订阅 → API Key → Bedrock),默认 false + // 仅对 Anthropic 平台生效;启用后账号按类型分层,优先使用订阅账号,依次降级。 + EnableTierFallbackChain bool `mapstructure:"enable_tier_fallback_chain"` } func (s *ServerConfig) Address() string { diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 48488dc8..5df0b58c 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -41,6 +41,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 23a7ccbc..7e238850 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1192,6 +1192,9 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { + if platform == PlatformAnthropic && s.enableTierFallbackChain() { + return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs) + } return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } diff --git a/backend/internal/service/gateway_tier_fallback.go b/backend/internal/service/gateway_tier_fallback.go new file mode 100644 index 00000000..35e7af36 --- /dev/null +++ b/backend/internal/service/gateway_tier_fallback.go @@ -0,0 +1,133 @@ +package service + +import ( + "context" + "errors" +) + +// accountTierLevel maps an account type to a scheduling tier: +// +// 0 = subscription (OAuth / SetupToken) — tried first +// 1 = API Key — first fallback +// 2 = Bedrock — last resort +// +// Accounts with an unknown type fall into tier 0 so they participate in the +// primary selection and do not vanish silently. +func accountTierLevel(account *Account) int { + if account == nil { + return 0 + } + switch account.Type { + case AccountTypeAPIKey: + return 1 + case AccountTypeBedrock: + return 2 + default: // OAuth, SetupToken, or unknown + return 0 + } +} + +// enableTierFallbackChain reports whether the cross-tier fallback chain is +// enabled in config (default false). +func (s *GatewayService) enableTierFallbackChain() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.Scheduling.EnableTierFallbackChain +} + +// selectAccountWithTierFallback tries Anthropic accounts in tier order: +// tier 0 (OAuth/SetupToken subscription) → tier 1 (API Key) → tier 2 (Bedrock). +// +// Sticky sessions are honored within the chain: if the session-bound account is +// in a tier that still has capacity it is returned immediately; otherwise the +// session binding is cleared and the chain proceeds from tier 0. +func (s *GatewayService) selectAccountWithTierFallback( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*Account, error) { + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAnthropic, false) + if err != nil { + return nil, err + } + + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + // Build per-tier candidate lists (pointers into `accounts`). + const numTiers = 3 + tierCandidates := [numTiers][]*Account{} + for i := range accounts { + acc := &accounts[i] + if acc.Platform != PlatformAnthropic { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + if !s.isAccountSchedulableForSelection(acc) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + tier := accountTierLevel(acc) + if tier < numTiers { + tierCandidates[tier] = append(tierCandidates[tier], acc) + } + } + + cfg := s.schedulingConfig() + selectionMode := cfg.FallbackSelectionMode + + // Check sticky session: if the bound account is a valid candidate, use it. + if sessionHash != "" && s.cache != nil { + accountID, cacheErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if cacheErr == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + for tier := 0; tier < numTiers; tier++ { + for _, acc := range tierCandidates[tier] { + if acc.ID != accountID { + continue + } + if shouldClearStickySession(acc, requestedModel) { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + break + } + if s.isAccountSchedulableForWindowCost(ctx, acc, true) && + s.isAccountSchedulableForRPM(ctx, acc, true) { + return acc, nil + } + } + } + } + } + } + + // Try each tier in order. + for tier := 0; tier < numTiers; tier++ { + candidates := tierCandidates[tier] + if len(candidates) == 0 { + continue + } + s.sortCandidatesForFallback(candidates, false, selectionMode) + result, acquired := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, false) + if acquired && result != nil { + return result.Account, nil + } + } + + return nil, errors.New("no available accounts in any tier") +} diff --git a/backend/internal/service/gateway_tier_fallback_test.go b/backend/internal/service/gateway_tier_fallback_test.go new file mode 100644 index 00000000..50a0ba9b --- /dev/null +++ b/backend/internal/service/gateway_tier_fallback_test.go @@ -0,0 +1,138 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestAccountTierLevel(t *testing.T) { + require.Equal(t, 0, accountTierLevel(nil)) + require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeOAuth})) + require.Equal(t, 0, accountTierLevel(&Account{Type: AccountTypeSetupToken})) + require.Equal(t, 0, accountTierLevel(&Account{Type: "unknown"})) + require.Equal(t, 1, accountTierLevel(&Account{Type: AccountTypeAPIKey})) + require.Equal(t, 2, accountTierLevel(&Account{Type: AccountTypeBedrock})) +} + +func TestGatewayService_EnableTierFallbackChain(t *testing.T) { + require.False(t, (*GatewayService)(nil).enableTierFallbackChain()) + require.False(t, (&GatewayService{}).enableTierFallbackChain()) + + cfgOff := &config.Config{} + cfgOff.Gateway.Scheduling.EnableTierFallbackChain = false + require.False(t, (&GatewayService{cfg: cfgOff}).enableTierFallbackChain()) + + cfgOn := &config.Config{} + cfgOn.Gateway.Scheduling.EnableTierFallbackChain = true + require.True(t, (&GatewayService{cfg: cfgOn}).enableTierFallbackChain()) +} + +// TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription verifies +// that when both OAuth (subscription) and APIKey accounts are available, the +// tier-0 OAuth account is always selected first even if APIKey has higher priority. +func TestGatewayService_SelectAccountWithTierFallback_PrefersSubscription(t *testing.T) { + ctx := context.Background() + + oauthAcc := Account{ID: 91001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Priority: 5} + apiKeyAcc := Account{ID: 91002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Priority: 0} + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{oauthAcc, apiKeyAcc}, + accountsByID: map[int64]*Account{91001: &oauthAcc, 91002: &apiKeyAcc}, + } + cache := &mockGatewayCacheForPlatform{} + svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()} + + acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(91001), acc.ID, "OAuth (tier-0) account should be preferred over APIKey (tier-1)") +} + +// TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey verifies +// that when the subscription tier has no schedulable accounts, the fallback +// selects an API Key account. +func TestGatewayService_SelectAccountWithTierFallback_FallsBackToAPIKey(t *testing.T) { + ctx := context.Background() + + rateLimitedUntil := time.Now().Add(30 * time.Minute) + oauthAcc := Account{ID: 92001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil} + apiKeyAcc := Account{ID: 92002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true} + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{oauthAcc, apiKeyAcc}, + accountsByID: map[int64]*Account{92001: &oauthAcc, 92002: &apiKeyAcc}, + } + cache := &mockGatewayCacheForPlatform{} + svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()} + + acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(92002), acc.ID, "Should fall back to APIKey when OAuth is rate-limited") +} + +// TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts ensures +// excluded IDs are respected across all tiers. +func TestGatewayService_SelectAccountWithTierFallback_ExcludesAccounts(t *testing.T) { + ctx := context.Background() + + oauthAcc := Account{ID: 93001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true} + apiKeyAcc := Account{ID: 93002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true} + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{oauthAcc, apiKeyAcc}, + accountsByID: map[int64]*Account{93001: &oauthAcc, 93002: &apiKeyAcc}, + } + cache := &mockGatewayCacheForPlatform{} + svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()} + + excluded := map[int64]struct{}{93001: {}} + acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", excluded) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(93002), acc.ID, "Excluded OAuth account should cause APIKey fallback") +} + +// TestGatewayService_SelectAccountWithTierFallback_NoAccounts verifies that +// an error is returned when all tiers are empty. +func TestGatewayService_SelectAccountWithTierFallback_NoAccounts(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{accounts: nil, accountsByID: map[int64]*Account{}} + cache := &mockGatewayCacheForPlatform{} + svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()} + + acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil) + require.Error(t, err) + require.Nil(t, acc) +} + +// TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort verifies +// that Bedrock accounts are only used when subscription and API Key tiers are exhausted. +func TestGatewayService_SelectAccountWithTierFallback_BedrockLastResort(t *testing.T) { + ctx := context.Background() + + rateLimitedUntil := time.Now().Add(30 * time.Minute) + oauthAcc := Account{ID: 94001, Platform: PlatformAnthropic, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil} + apiKeyAcc := Account{ID: 94002, Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, RateLimitResetAt: &rateLimitedUntil} + bedrockAcc := Account{ID: 94003, Platform: PlatformAnthropic, Type: AccountTypeBedrock, Status: StatusActive, Schedulable: true} + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{oauthAcc, apiKeyAcc, bedrockAcc}, + accountsByID: map[int64]*Account{94001: &oauthAcc, 94002: &apiKeyAcc, 94003: &bedrockAcc}, + } + cache := &mockGatewayCacheForPlatform{} + svc := &GatewayService{accountRepo: repo, cache: cache, cfg: testConfig()} + + acc, err := svc.selectAccountWithTierFallback(ctx, nil, "", "", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(94003), acc.ID, "Bedrock should be selected as last resort") +}