From 496469ac4e22a90f417ce1f1b48ff8868f938183 Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 25 Apr 2026 22:50:35 +0800 Subject: [PATCH 01/46] fix(gateway): skip body mimicry for real Claude Code clients to restore prompt caching PR #1914 unconditionally applied the full mimicry pipeline to all OAuth accounts, including real Claude Code CLI clients. This replaced the client's long system prompt (~10K+ tokens with stable cache_control breakpoints) with a short ~45 token [billing, CC prompt] pair, which falls below Anthropic's 1024-token minimum cacheable prefix threshold. The result: every request created a new cache but never hit an existing one. Fix: restore the Claude Code client detection gate so that real CC clients bypass body-level mimicry (system rewrite, message cache management, tool name obfuscation). Non-CC third-party clients (opencode, etc.) continue to receive full mimicry. Also harden the detection logic: - Make UA regex case-insensitive (align with claude_code_validator.go) - Validate metadata.user_id format via ParseMetadataUserID() instead of just checking non-empty, preventing third-party tools from spoofing a claude-cli/* UA with an arbitrary user_id string to bypass mimicry --- .../internal/service/gateway_prompt_test.go | 41 +++++++++++++++---- backend/internal/service/gateway_service.go | 34 +++++++++------ 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index 443486ab..f3a22c1d 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -9,6 +9,11 @@ import ( ) func TestIsClaudeCodeClient(t *testing.T) { + // 合法的 legacy 格式 metadata.user_id(64位 hex + account uuid + session uuid) + legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000" + // 合法的 JSON 格式 
metadata.user_id(2.1.78+ 版本) + jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}` + tests := []struct { name string userAgent string @@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) { want bool }{ { - name: "Claude Code client", + name: "Claude Code client with legacy user_id", userAgent: "claude-cli/1.0.62 (darwin; arm64)", - metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + metadataUserID: legacyUserID, want: true, }, { - name: "Claude Code without version suffix", - userAgent: "claude-cli/2.0.0", - metadataUserID: "session_abc", + name: "Claude Code client with JSON user_id", + userAgent: "claude-cli/2.1.92 (external, cli)", + metadataUserID: jsonUserID, + want: true, + }, + { + name: "Claude Code case insensitive UA", + userAgent: "Claude-CLI/2.0.0", + metadataUserID: legacyUserID, want: true, }, { @@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) { want: false, }, { - name: "Different user agent", + name: "Claude CLI UA with invalid user_id format", + userAgent: "claude-cli/2.0.0", + metadataUserID: "fake-user-id-12345", + want: false, + }, + { + name: "Different user agent with valid user_id", userAgent: "curl/7.68.0", - metadataUserID: "user123", + metadataUserID: legacyUserID, want: false, }, { name: "Empty user agent", userAgent: "", - metadataUserID: "user123", + metadataUserID: legacyUserID, want: false, }, { name: "Similar but not Claude CLI", userAgent: "claude-api/1.0.0", - metadataUserID: "user123", + metadataUserID: legacyUserID, + want: false, + }, + { + name: "Opencode spoofing UA with arbitrary user_id", + userAgent: "claude-cli/2.1.92", + metadataUserID: "session_abc", want: false, }, } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index ffd66fc7..6be19ba6 100644 --- 
a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -329,7 +329,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) - claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -3709,13 +3709,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error { } } -// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 -// 简化判断:User-Agent 匹配 + metadata.user_id 存在 +// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。 +// 判定条件: +// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感) +// 2. metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式) +// +// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA +// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID +// 验证格式才能确认是真正的 Claude Code 客户端。 func isClaudeCodeClient(userAgent string, metadataUserID string) bool { - if metadataUserID == "" { + if !claudeCliUserAgentRe.MatchString(userAgent) { return false } - return claudeCliUserAgentRe.MatchString(userAgent) + return ParseMetadataUserID(metadataUserID) != nil } // normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), @@ -4144,12 +4150,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) } - // OAuth 账号无条件走完整 mimicry,与 Parrot 对齐。 - // 不再检查 isClaudeCodeRequest —— 即使客户端自称 Claude Code(opencode 等 - // 第三方工具会伪装 UA / X-App / system prompt),它的伪装往往不完整(缺 billing - // block / 工具名混淆 / cache 策略等),被 Anthropic 判为 third-party。 - // 无条件覆盖不会对真正的 Claude Code 造成问题,因为我们的伪装更完整。 - shouldMimicClaudeCode := account.IsOAuth() + // Claude Code 客户端判定:UA 匹配 claude-cli/* 且携带 metadata.user_id。 + // 真正的 Claude Code 客户端自带完整的 system 
prompt、cache_control 断点和 header, + // 不需要代理做任何 body 级别的 mimicry;强行替换反而会破坏客户端的缓存策略 + // (长 system prompt 被替换为 ~45 tokens 的短 prompt,低于 Anthropic 1024 token + // 最低缓存门槛,导致系统级缓存失效)。 + // + // 对于非 Claude Code 的第三方客户端(opencode 等),仍然走完整 mimicry。 + isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { // 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code @@ -8387,7 +8396,8 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // Pre-filter: strip empty text blocks to prevent upstream 400. body = StripEmptyTextBlocks(body) - shouldMimicClaudeCode := account.IsOAuth() + isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT if shouldMimicClaudeCode { normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} From b17704d6effc717e5644ad09f61abe9aa2296775 Mon Sep 17 00:00:00 2001 From: deqiying Date: Sun, 26 Apr 2026 01:14:59 +0800 Subject: [PATCH 02/46] =?UTF-8?q?fix(anthropic):=20=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=E7=BC=93=E5=AD=98=20token=20=E7=9A=84=20Anthropic=20=E7=94=A8?= =?UTF-8?q?=E9=87=8F=E8=AF=AD=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pkg/apicompat/anthropic_responses_test.go | 79 +++++++++++++++++++ .../pkg/apicompat/responses_to_anthropic.go | 39 ++++++--- 2 files changed, 106 insertions(+), 12 deletions(-) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 095305c2..c35b51b6 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) { 
assert.Equal(t, 5, anth.Usage.OutputTokens) } +func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cached", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Cached response"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 54006, + OutputTokens: 123, + TotalTokens: 54129, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 50688, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929") + assert.Equal(t, 3318, anth.Usage.InputTokens) + assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens) + assert.Equal(t, 123, anth.Usage.OutputTokens) +} + +func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cached_clamp", + Model: "gpt-5.2", + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 5, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 150, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929") + assert.Equal(t, 0, anth.Usage.InputTokens) + assert.Equal(t, 150, anth.Usage.CacheReadInputTokens) + assert.Equal(t, 5, anth.Usage.OutputTokens) +} + func TestResponsesToAnthropic_ToolUse(t *testing.T) { resp := &ResponsesResponse{ ID: "resp_456", @@ -343,6 +392,36 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { + state := NewResponsesEventToAnthropicState() + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + 
Usage: &ResponsesUsage{ + InputTokens: 54006, + OutputTokens: 123, + TotalTokens: 54129, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 50688, + }, + }, + }, + }, state) + + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, 3318, events[0].Usage.InputTokens) + assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens) + assert.Equal(t, 123, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) +} + func TestStreamingToolCall(t *testing.T) { state := NewResponsesEventToAnthropicState() diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 5409a0f4..40bed302 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks) if resp.Usage != nil { - out.Usage = AnthropicUsage{ - InputTokens: resp.Usage.InputTokens, - OutputTokens: resp.Usage.OutputTokens, - } - if resp.Usage.InputTokensDetails != nil { - out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens - } + out.Usage = anthropicUsageFromResponsesUsage(resp.Usage) } return out } +func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage { + if usage == nil { + return AnthropicUsage{} + } + + cachedTokens := 0 + if usage.InputTokensDetails != nil { + cachedTokens = usage.InputTokensDetails.CachedTokens + } + + inputTokens := usage.InputTokens - cachedTokens + if inputTokens < 0 { + inputTokens = 0 + } + + return AnthropicUsage{ + InputTokens: inputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: cachedTokens, + } +} + func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks 
[]AnthropicContentBlock) string { switch status { case "incomplete": @@ -466,11 +482,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo stopReason := "end_turn" if evt.Response != nil { if evt.Response.Usage != nil { - state.InputTokens = evt.Response.Usage.InputTokens - state.OutputTokens = evt.Response.Usage.OutputTokens - if evt.Response.Usage.InputTokensDetails != nil { - state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens - } + usage := anthropicUsageFromResponsesUsage(evt.Response.Usage) + state.InputTokens = usage.InputTokens + state.OutputTokens = usage.OutputTokens + state.CacheReadInputTokens = usage.CacheReadInputTokens } switch evt.Response.Status { case "incomplete": From 489a4d934e4f91560bdc73e91ac91dd133b5b3b1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sat, 25 Apr 2026 19:46:32 -0400 Subject: [PATCH 03/46] Show today stats for Vertex usage window --- .../components/account/AccountUsageCell.vue | 35 +++++++++++++ .../__tests__/AccountUsageCell.spec.ts | 52 +++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 1c023fb3..2c04e673 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -332,6 +332,37 @@
+
+
+ + {{ formatKeyRequests }} req + + + {{ formatKeyTokens }} + + + A ${{ formatKeyCost }} + + + U ${{ formatKeyUserCost }} + +
+
+
+
+
+
+
@@ -512,6 +543,10 @@ const shouldFetchUsage = computed(() => { return false }) +const showGeminiTodayStats = computed(() => { + return props.account.platform === 'gemini' && props.account.type === 'service_account' +}) + const geminiUsageAvailable = computed(() => { return ( !!usageInfo.value?.gemini_shared_daily || diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts index 9158da64..fa4104f6 100644 --- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts +++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts @@ -57,6 +57,19 @@ function makeAccount(overrides: Partial): Account { describe('AccountUsageCell', () => { beforeEach(() => { getUsage.mockReset() + Object.defineProperty(window, 'matchMedia', { + writable: true, + value: vi.fn().mockImplementation(() => ({ + matches: true, + media: '(min-width: 768px)', + onchange: null, + addListener: vi.fn(), + removeListener: vi.fn(), + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + dispatchEvent: vi.fn(), + })) + }) }) it('Antigravity 图片用量会聚合新旧 image 模型', async () => { @@ -603,4 +616,43 @@ describe('AccountUsageCell', () => { expect(wrapper.text().trim()).toBe('-') }) + + it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => { + const wrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 4001, + platform: 'gemini', + type: 'service_account', + credentials: { + tier_id: 'vertex', + project_id: 'vertex-proj', + client_email: 'svc@vertex-proj.iam.gserviceaccount.com', + location: 'global' + }, + extra: {} + }), + todayStats: { + requests: 0, + tokens: 0, + cost: 0, + standard_cost: 0, + user_cost: 0 + } + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('0 req') + expect(wrapper.text()).toContain('0') + expect(wrapper.text()).toContain('A $0.00') + 
expect(wrapper.text()).toContain('U $0.00') + }) }) From 6d11f9ed77837968ba35188ebdc980ef60740e50 Mon Sep 17 00:00:00 2001 From: Oliver Date: Sat, 25 Apr 2026 20:39:58 -0400 Subject: [PATCH 04/46] Add Vertex service account support --- backend/cmd/server/wire_gen.go | 4 +- backend/internal/domain/constants.go | 11 +- .../internal/handler/admin/account_handler.go | 4 +- .../internal/service/account_test_service.go | 101 +++++- .../internal/service/claude_token_provider.go | 48 ++- .../service/claude_token_provider_test.go | 8 +- backend/internal/service/domain_constants.go | 11 +- ...y_anthropic_vertex_service_account_test.go | 68 ++++ backend/internal/service/gateway_service.go | 88 ++++- .../service/gemini_messages_compat_service.go | 59 +++- .../internal/service/gemini_token_provider.go | 53 ++- .../service/vertex_service_account.go | 303 +++++++++++++++++ .../service/vertex_service_account_test.go | 77 +++++ .../components/account/CreateAccountModal.vue | 310 +++++++++++++++++- .../components/account/EditAccountModal.vue | 129 ++++++++ .../components/common/PlatformTypeBadge.vue | 3 + frontend/src/types/index.ts | 2 +- 17 files changed, 1243 insertions(+), 36 deletions(-) create mode 100644 backend/internal/service/gateway_anthropic_vertex_service_account_test.go create mode 100644 backend/internal/service/vertex_service_account.go create mode 100644 backend/internal/service/vertex_service_account_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f767bbea..dea46561 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -145,13 +145,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) oAuthRefreshAPI := 
service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) + claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) @@ -178,7 +179,6 
@@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index a57f7067..27c543dd 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -26,11 +26,12 @@ const ( // Account type constants const ( - AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = "apikey" // API Key类型账号 - AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) - AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI) ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 7454451a..e69e056f 100644 --- 
a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -98,7 +98,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -117,7 +117,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index c0bbc6dc..aa657e0e 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -64,6 +64,7 @@ func isOpenAIImageModel(model string) bool { type AccountTestService struct { accountRepo AccountRepository geminiTokenProvider *GeminiTokenProvider + claudeTokenProvider *ClaudeTokenProvider antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config @@ -74,6 +75,7 @@ type AccountTestService struct { func NewAccountTestService( accountRepo AccountRepository, geminiTokenProvider *GeminiTokenProvider, + claudeTokenProvider *ClaudeTokenProvider, antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, cfg *config.Config, @@ -82,6 +84,7 @@ func 
NewAccountTestService( return &AccountTestService{ accountRepo: accountRepo, geminiTokenProvider: geminiTokenProvider, + claudeTokenProvider: claudeTokenProvider, antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, @@ -210,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if account.IsBedrock() { return s.testBedrockAccountConnection(c, ctx, account, testModelID) } + if account.Type == AccountTypeServiceAccount { + return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID) + } // Determine authentication method and API URL var authToken string @@ -313,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.processClaudeStream(c, resp.Body) } +func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { + if mappedModel, matched := account.ResolveMappedModel(testModelID); matched { + testModelID = mappedModel + } else { + testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID)) + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + payload, err := createTestPayload(testModelID) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create test payload") + } + payloadBytes, _ := json.Marshal(payload) + vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error())) + } + + if s.claudeTokenProvider == nil { + return s.sendErrorAndEnd(c, "Claude token provider not configured") + } + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return 
s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error())) + } + + fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error())) + } + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)) + if resp.StatusCode == http.StatusForbidden { + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } + return s.sendErrorAndEnd(c, errMsg) + } + + return s.processClaudeStream(c, resp.Body) +} + // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { region := bedrockRuntimeRegion(account) @@ -711,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account testModelID = geminicli.DefaultTestModel } - // For API Key accounts with model mapping, map the model - if account.Type == AccountTypeAPIKey { + // For static 
upstream credentials with model mapping, map the model + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mapping := account.GetModelMapping() if len(mapping) > 0 { if mappedModel, exists := mapping[testModelID]; exists { @@ -740,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) case AccountTypeOAuth: req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) + case AccountTypeServiceAccount: + req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload) default: return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -893,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload) } +func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { + if s.geminiTokenProvider == nil { + return nil, fmt.Errorf("gemini token provider not configured") + } + accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("failed to get service account access token: %w", err) + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + return req, nil +} + // buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity) func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, 
accessToken, projectID, modelID string, payload []byte) (*http.Request, error) { var inner map[string]any diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index 82fa31c4..9292979f 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -17,7 +17,7 @@ const ( // ClaudeTokenCache token cache interface. type ClaudeTokenCache = GeminiTokenCache -// ClaudeTokenProvider manages access_token for Claude OAuth accounts. +// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts. type ClaudeTokenProvider struct { accountRepo AccountRepository tokenCache ClaudeTokenCache @@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { - return "", errors.New("not an anthropic oauth account") + if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) { + return "", errors.New("not an anthropic oauth or service account") + } + if account.Type == AccountTypeServiceAccount { + return p.getServiceAccountAccessToken(ctx, account) } cacheKey := ClaudeTokenCacheKey(account) @@ -157,3 +160,42 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { + key, err := parseVertexServiceAccountKey(account) + if err != nil { + return "", err + } + cacheKey := vertexServiceAccountCacheKey(account, key) + + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + locked := false + if p.tokenCache != nil { + 
var lockErr error + locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { + time.Sleep(claudeLockWaitTime) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + } + + accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) + if err != nil { + return "", err + } + if p.tokenCache != nil { + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + return accessToken, nil +} diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go index 3e21f6f4..d4a4a14a 100644 --- a/backend/internal/service/claude_token_provider_test.go +++ b/backend/internal/service/claude_token_provider_test.go @@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A return "", errors.New("account is nil") } if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { - return "", errors.New("not an anthropic oauth account") + return "", errors.New("not an anthropic oauth or service account") } cacheKey := ClaudeTokenCacheKey(account) @@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Contains(t, err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } @@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an 
anthropic oauth account") + require.Contains(t, err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } @@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Contains(t, err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 04037987..e3d3a872 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -36,11 +36,12 @@ const ( // Account type constants const ( - AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 - AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) - AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI) ) // Redeem type constants diff --git a/backend/internal/service/gateway_anthropic_vertex_service_account_test.go 
b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go new file mode 100644 index 00000000..aa779805 --- /dev/null +++ b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go @@ -0,0 +1,68 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("Anthropic-Version", "2023-06-01") + c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + + account := &Account{ + ID: 301, + Platform: PlatformAnthropic, + Type: AccountTypeServiceAccount, + Credentials: map[string]any{ + "project_id": "vertex-proj", + "location": "us-east5", + }, + } + body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`) + + svc := &GatewayService{} + req, err := svc.buildUpstreamRequest( + context.Background(), + c, + account, + body, + "vertex-token", + "service_account", + "claude-sonnet-4-5@20250929", + false, + false, + ) + require.NoError(t, err) + require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String()) + require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization")) + require.Empty(t, getHeaderRaw(req.Header, "x-api-key")) + require.Empty(t, getHeaderRaw(req.Header, "anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, 
"anthropic-beta")) + + got := readRequestBodyForTest(t, req) + require.Equal(t, "", gjson.GetBytes(got, "model").String()) + require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String()) + require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String()) +} + +func readRequestBodyForTest(t *testing.T, req *http.Request) []byte { + t.Helper() + require.NotNil(t, req.Body) + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + return body +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6be19ba6..75725753 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3597,7 +3597,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { - requestedModel = claude.NormalizeModelID(requestedModel) + if account.Type == AccountTypeServiceAccount { + requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel)) + } else { + requestedModel = claude.NormalizeModelID(requestedModel) + } } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) @@ -3617,6 +3621,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( return apiKey, "apikey", nil case AccountTypeBedrock: return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 + case AccountTypeServiceAccount: + if account.Platform != PlatformAnthropic { + return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform) + } + if s.claudeTokenProvider == nil { + return "", "", errors.New("claude token provider not configured") + } + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "service_account", nil default: 
return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -4219,6 +4235,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + if candidate, matched := account.ResolveMappedModel(reqModel); matched { + mappedModel = candidate + mappingSource = "account" + } else { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel)) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "vertex" + } + } + } if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(reqModel) if normalized != reqModel { @@ -5688,6 +5716,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse( } func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { + if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream) + } + // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { @@ -5874,6 +5906,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex return req, nil } +func (s *GatewayService) buildUpstreamRequestAnthropicVertex( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, + modelID string, + reqStream bool, +) (*http.Request, error) { + vertexBody, err := buildVertexAnthropicRequestBody(body) + if err != nil { + return nil, err + } + setOpsUpstreamRequestBody(c, vertexBody) + fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream) + if err 
!= nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" { + continue + } + wireKey := resolveWireCasing(key) + for _, v := range values { + addHeaderRaw(req.Header, wireKey, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Del("anthropic-version") + setHeaderRaw(req.Header, "authorization", "Bearer "+token) + setHeaderRaw(req.Header, "content-type", "application/json") + + s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{ + "url": req.URL.String(), + "token_type": "service_account", + "model": modelID, + "stream": strconv.FormatBool(reqStream), + }) + + return req, nil +} + // getBetaHeader 处理anthropic-beta header // 对于OAuth账号,需要确保包含oauth-2025-04-20 func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 7a24071b..20293ac8 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -579,7 +579,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex originalModel := req.Model mappedModel := req.Model - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(req.Model) } @@ -712,6 +712,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } requestIDHeader = "x-request-id" + case 
AccountTypeServiceAccount: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream) + if err != nil { + return nil, "", err + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + default: return nil, fmt.Errorf("unsupported account type: %s", account.Type) } @@ -1094,7 +1124,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. body = ensureGeminiFunctionCallThoughtSignatures(body) mappedModel := originalModel - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } @@ -1213,6 +1243,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
} requestIDHeader = "x-request-id" + case AccountTypeServiceAccount: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream) + if err != nil { + return nil, "", err + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + default: return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type) } diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 7add3460..c22f2131 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -15,7 +15,7 @@ const ( geminiTokenCacheSkew = 5 * time.Minute ) -// GeminiTokenProvider manages access_token for Gemini OAuth accounts. +// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts. 
type GeminiTokenProvider struct { accountRepo AccountRepository tokenCache GeminiTokenCache @@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth { - return "", errors.New("not a gemini oauth account") + if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) { + return "", errors.New("not a gemini oauth or service account") + } + if account.Type == AccountTypeServiceAccount { + return p.getServiceAccountAccessToken(ctx, account) } cacheKey := GeminiTokenCacheKey(account) @@ -168,7 +171,51 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } +func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { + key, err := parseVertexServiceAccountKey(account) + if err != nil { + return "", err + } + cacheKey := vertexServiceAccountCacheKey(account, key) + + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + locked := false + if p.tokenCache != nil { + var lockErr error + locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { + time.Sleep(200 * time.Millisecond) + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + } + + accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) + if err != nil { + return "", err + } + if p.tokenCache != nil { + 
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + return accessToken, nil +} + func GeminiTokenCacheKey(account *Account) string { + if account != nil && account.Type == AccountTypeServiceAccount { + if key, err := parseVertexServiceAccountKey(account); err == nil { + return vertexServiceAccountCacheKey(account, key) + } + } projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { return "gemini:" + projectID diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go new file mode 100644 index 00000000..d4130b93 --- /dev/null +++ b/backend/internal/service/vertex_service_account.go @@ -0,0 +1,303 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + vertexDefaultLocation = "us-central1" + vertexDefaultTokenURL = "https://oauth2.googleapis.com/token" + vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" + vertexServiceAccountCacheSkew = 5 * time.Minute + vertexAnthropicVersion = "vertex-2023-10-16" +) + +var ( + vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`) + vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`) + vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`) +) + +type vertexServiceAccountKey struct { + Type string `json:"type"` + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + TokenURI string `json:"token_uri"` +} + +type vertexTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Error string `json:"error"` + ErrorDesc string `json:"error_description"` +} + +func (a *Account) 
IsVertexServiceAccount() bool { + return a != nil && a.Type == AccountTypeServiceAccount +} + +func (a *Account) VertexProjectID() string { + if a == nil { + return "" + } + if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" { + return v + } + key, err := parseVertexServiceAccountKey(a) + if err == nil { + return strings.TrimSpace(key.ProjectID) + } + return "" +} + +func (a *Account) VertexLocation(model string) string { + if a == nil { + return vertexDefaultLocation + } + if model != "" && a.Credentials != nil { + if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok { + if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" { + return strings.TrimSpace(loc) + } + } + } + if v := strings.TrimSpace(a.GetCredential("location")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" { + return v + } + return vertexDefaultLocation +} + +func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) { + if account == nil || account.Credentials == nil { + return nil, errors.New("service account credentials not configured") + } + + if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" { + return parseVertexServiceAccountJSON([]byte(raw)) + } + if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" { + return parseVertexServiceAccountJSON([]byte(raw)) + } + if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok { + b, _ := json.Marshal(nested) + return parseVertexServiceAccountJSON(b) + } + if nested, ok := account.Credentials["service_account"].(map[string]any); ok { + b, _ := json.Marshal(nested) + return parseVertexServiceAccountJSON(b) + } + return nil, errors.New("service_account_json not found in credentials") +} + +func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) { + var key vertexServiceAccountKey + if err := json.Unmarshal(raw, 
&key); err != nil { + return nil, fmt.Errorf("invalid service account json: %w", err) + } + if strings.TrimSpace(key.ClientEmail) == "" { + return nil, errors.New("service account json missing client_email") + } + if strings.TrimSpace(key.PrivateKey) == "" { + return nil, errors.New("service account json missing private_key") + } + if strings.TrimSpace(key.ProjectID) == "" { + return nil, errors.New("service account json missing project_id") + } + if strings.TrimSpace(key.TokenURI) == "" { + key.TokenURI = vertexDefaultTokenURL + } + return &key, nil +} + +func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string { + fingerprint := "" + if key != nil { + sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID)) + fingerprint = hex.EncodeToString(sum[:8]) + } + if fingerprint == "" && account != nil { + fingerprint = fmt.Sprintf("account:%d", account.ID) + } + return "vertex:service_account:" + fingerprint +} + +func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) { + now := time.Now() + claims := jwt.MapClaims{ + "iss": key.ClientEmail, + "scope": vertexCloudPlatformScope, + "aud": key.TokenURI, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + if strings.TrimSpace(key.PrivateKeyID) != "" { + token.Header["kid"] = key.PrivateKeyID + } + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey)) + if err != nil { + return "", 0, fmt.Errorf("parse service account private key: %w", err) + } + assertion, err := token.SignedString(privateKey) + if err != nil { + return "", 0, fmt.Errorf("sign service account assertion: %w", err) + } + + values := url.Values{} + values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + values.Set("assertion", assertion) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, 
strings.NewReader(values.Encode())) + if err != nil { + return "", 0, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", 0, fmt.Errorf("service account token request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + var parsed vertexTokenResponse + _ = json.Unmarshal(body, &parsed) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg := strings.TrimSpace(parsed.ErrorDesc) + if msg == "" { + msg = strings.TrimSpace(parsed.Error) + } + if msg == "" { + msg = string(bytes.TrimSpace(body)) + } + return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg) + } + if strings.TrimSpace(parsed.AccessToken) == "" { + return "", 0, errors.New("service account token response missing access_token") + } + ttl := time.Duration(parsed.ExpiresIn) * time.Second + if ttl <= 0 { + ttl = time.Hour + } + if ttl > vertexServiceAccountCacheSkew { + ttl -= vertexServiceAccountCacheSkew + } + return parsed.AccessToken, ttl, nil +} + +func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) { + projectID = strings.TrimSpace(projectID) + location = strings.TrimSpace(location) + model = strings.TrimSpace(model) + action = strings.TrimSpace(action) + if projectID == "" { + return "", errors.New("vertex project_id is required") + } + if location == "" { + location = vertexDefaultLocation + } + if !vertexLocationPattern.MatchString(location) { + return "", fmt.Errorf("invalid vertex location: %s", location) + } + if model == "" { + return "", errors.New("vertex model is required") + } + switch action { + case "generateContent", "streamGenerateContent", "countTokens": + default: + return "", fmt.Errorf("unsupported vertex gemini action: %s", action) + } + host := 
fmt.Sprintf("%s-aiplatform.googleapis.com", location) + if location == "global" { + host = "aiplatform.googleapis.com" + } + u := fmt.Sprintf( + "https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + host, + url.PathEscape(projectID), + url.PathEscape(location), + url.PathEscape(model), + action, + ) + if stream { + u += "?alt=sse" + } + return u, nil +} + +func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) { + projectID = strings.TrimSpace(projectID) + location = strings.TrimSpace(location) + model = strings.TrimSpace(model) + if projectID == "" { + return "", errors.New("vertex project_id is required") + } + if location == "" { + location = vertexDefaultLocation + } + if !vertexLocationPattern.MatchString(location) { + return "", fmt.Errorf("invalid vertex location: %s", location) + } + if model == "" { + return "", errors.New("vertex model is required") + } + action := "rawPredict" + if stream { + action = "streamRawPredict" + } + host := fmt.Sprintf("%s-aiplatform.googleapis.com", location) + if location == "global" { + host = "aiplatform.googleapis.com" + } + escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@") + return fmt.Sprintf( + "https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", + host, + url.PathEscape(projectID), + url.PathEscape(location), + escapedModel, + action, + ), nil +} + +func normalizeVertexAnthropicModelID(model string) string { + model = strings.TrimSpace(model) + if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) { + return model + } + if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 { + return m[1] + "@" + m[2] + } + return model +} + +func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("parse anthropic vertex request body: %w", err) + } + delete(payload, 
"model") + payload["anthropic_version"] = vertexAnthropicVersion + return json.Marshal(payload) +} diff --git a/backend/internal/service/vertex_service_account_test.go b/backend/internal/service/vertex_service_account_test.go new file mode 100644 index 00000000..519f5b2f --- /dev/null +++ b/backend/internal/service/vertex_service_account_test.go @@ -0,0 +1,77 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildVertexGeminiURL(t *testing.T) { + got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true) + require.NoError(t, err) + require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got) +} + +func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) { + got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true) + require.NoError(t, err) + require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got) +} + +func TestBuildVertexAnthropicURL(t *testing.T) { + got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false) + require.NoError(t, err) + require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got) +} + +func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) { + got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true) + require.NoError(t, err) + require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got) +} + 
+func TestNormalizeVertexAnthropicModelID(t *testing.T) { + require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929")) + require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929")) + require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6")) +} + +func TestBuildVertexAnthropicRequestBody(t *testing.T) { + got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)) + require.NoError(t, err) + require.Equal(t, "", gjson.GetBytes(got, "model").String()) + require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String()) + require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int()) + require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String()) +} + +func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) { + _, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vertex location") +} + +func TestParseVertexServiceAccountKey(t *testing.T) { + raw := `{ + "type": "service_account", + "project_id": "vertex-proj", + "private_key_id": "kid", + "private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n", + "client_email": "svc@vertex-proj.iam.gserviceaccount.com" + }` + account := &Account{ + Type: AccountTypeServiceAccount, + Platform: PlatformGemini, + Credentials: map[string]any{ + "service_account_json": raw, + }, + } + key, err := parseVertexServiceAccountKey(account) + require.NoError(t, err) + require.Equal(t, "vertex-proj", key.ProjectID) + require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail) + require.Equal(t, vertexDefaultTokenURL, key.TokenURI) + require.True(t, strings.Contains(key.PrivateKey, 
"BEGIN PRIVATE KEY")) +} diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 96673f8f..e7a790ec 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -153,7 +153,7 @@
-
+
+ + +
+ +
+

使用 Google Cloud Service Account JSON 通过 Vertex AI 调用 Anthropic Claude。建议配置模型映射,将客户端 Claude 模型名映射到 Vertex 模型 ID。

@@ -302,6 +335,7 @@ {{ t('admin.accounts.types.responsesApi') }}
+
@@ -320,7 +354,7 @@ {{ t('admin.accounts.gemini.helpButton') }}
-
+
+ +
+
+

使用 Google Cloud Service Account JSON 访问 Vertex AI Gemini。建议将 Vertex 账号放入独立分组,避免和 AI Studio/Gemini OAuth 同模型混调。

+
+
@@ -610,7 +681,7 @@
-
+
+
+
+
+
+ + {{ vertexClientEmail ? '已读取 Service Account JSON' : '拖入 Service Account JSON' }} +
+

+ {{ vertexClientEmail ? '密钥内容不会在表单中显示。' : '把 .json 文件拖到这里,或点击按钮选择文件。' }} +

+
+ +
+
+
Project ID: {{ vertexProjectId }}
+
Client Email: {{ vertexClientEmail }}
+
+
+

上传或拖入 JSON 后会自动读取 project_id,密钥内容仅用于创建账号提交。

+
+ +
+
+ + +
+
+ + +

不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。

+
+
+
+
@@ -3085,7 +3246,7 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') @@ -3151,6 +3312,58 @@ const bedrockSessionToken = ref('') const bedrockRegion = ref('us-east-1') const bedrockForceGlobal = ref(false) const bedrockApiKeyValue = ref('') +const vertexServiceAccountFileInput = ref(null) +const vertexServiceAccountJson = ref('') +const vertexProjectId = ref('') +const vertexClientEmail = ref('') +const vertexLocation = ref('global') +const vertexServiceAccountDragActive = ref(false) +const vertexLocationOptions = [ + { + label: 'Common', + options: [ + { value: 'us-central1', label: 'us-central1 (Iowa)' }, + { value: 'global', label: 'global' }, + { value: 'us', label: 'us' }, + { value: 'eu', label: 'eu' } + ] + }, + { + label: 'United States', + options: [ + { value: 'us-east1', label: 'us-east1 (South Carolina)' }, + { value: 'us-east4', label: 'us-east4 (Northern Virginia)' }, + { value: 'us-east5', label: 'us-east5 (Columbus)' }, + { value: 'us-south1', label: 'us-south1 (Dallas)' }, + { value: 'us-west1', label: 'us-west1 (Oregon)' }, + { value: 'us-west4', label: 'us-west4 (Las Vegas)' } + ] + }, + { + label: 'Europe', + options: [ + { value: 'europe-west1', label: 'europe-west1 (Belgium)' }, + { value: 'europe-west2', label: 'europe-west2 (London)' }, + { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' }, + { value: 'europe-west4', label: 'europe-west4 (Netherlands)' }, + { value: 'europe-west6', label: 'europe-west6 (Zurich)' }, + { value: 'europe-west8', label: 'europe-west8 (Milan)' }, + { 
value: 'europe-west9', label: 'europe-west9 (Paris)' } + ] + }, + { + label: 'Asia Pacific', + options: [ + { value: 'asia-east1', label: 'asia-east1 (Taiwan)' }, + { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' }, + { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' }, + { value: 'asia-northeast3', label: 'asia-northeast3 (Seoul)' }, + { value: 'asia-south1', label: 'asia-south1 (Mumbai)' }, + { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' }, + { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' } + ] + } +] as const const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -3397,7 +3610,7 @@ watch( // Sync form.type based on accountCategory, addMethod, and platform-specific type watch( - [accountCategory, addMethod, antigravityAccountType], + [accountCategory, addMethod, antigravityAccountType, () => form.platform], ([category, method, agType]) => { // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { @@ -3409,7 +3622,9 @@ watch( form.type = 'bedrock' as AccountType return } - if (category === 'oauth-based') { + if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') { + form.type = 'service_account' as AccountType + } else if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' } else { form.type = 'apikey' @@ -3447,6 +3662,12 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') { + accountCategory.value = 'oauth-based' + } + if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') { + accountCategory.value = 'oauth-based' + } // Reset Bedrock fields when switching platforms bedrockAccessKeyId.value = '' 
bedrockSecretAccessKey.value = '' @@ -3455,6 +3676,10 @@ watch( bedrockForceGlobal.value = false bedrockAuthMode.value = 'sigv4' bedrockApiKeyValue.value = '' + vertexServiceAccountJson.value = '' + vertexProjectId.value = '' + vertexClientEmail.value = '' + vertexLocation.value = 'global' // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -3886,6 +4111,10 @@ const resetForm = () => { antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' upstreamApiKey.value = '' + vertexServiceAccountJson.value = '' + vertexProjectId.value = '' + vertexClientEmail.value = '' + vertexLocation.value = 'global' tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -4009,6 +4238,52 @@ const normalizePoolModeRetryCount = (value: number) => { return normalized } +const applyVertexServiceAccountJson = (value: string) => { + const raw = value.trim() + if (!raw) { + vertexProjectId.value = '' + vertexClientEmail.value = '' + return false + } + try { + const parsed = JSON.parse(raw) as Record + const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : '' + const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : '' + const privateKey = typeof parsed.private_key === 'string' ? 
parsed.private_key.trim() : '' + if (!projectId || !clientEmail || !privateKey) { + appStore.showError('Service Account JSON 缺少 project_id、client_email 或 private_key') + return false + } + vertexProjectId.value = projectId + vertexClientEmail.value = clientEmail + vertexServiceAccountJson.value = JSON.stringify(parsed) + return true + } catch { + appStore.showError('Service Account JSON 格式无效') + return false + } +} + +const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value) + +const handleVertexServiceAccountFile = async (event: Event) => { + const input = event.target as HTMLInputElement + const file = input.files?.[0] + if (!file) return + try { + applyVertexServiceAccountJson(await file.text()) + } finally { + input.value = '' + } +} + +const handleVertexServiceAccountDrop = async (event: DragEvent) => { + vertexServiceAccountDragActive.value = false + const file = event.dataTransfer?.files?.[0] + if (!file) return + applyVertexServiceAccountJson(await file.text()) +} + const handleSubmit = async () => { // For OAuth-based type, handle OAuth flow (goes to step 2) if (isOAuthFlow.value) { @@ -4122,6 +4397,29 @@ const handleSubmit = async () => { return } + if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!parseVertexServiceAccountJson()) { + return + } + if (!vertexLocation.value.trim()) { + appStore.showError('请填写 Vertex location') + return + } + const credentials: Record = { + service_account_json: vertexServiceAccountJson.value.trim(), + project_id: vertexProjectId.value.trim(), + client_email: vertexClientEmail.value.trim(), + location: vertexLocation.value.trim(), + tier_id: 'vertex' + } + await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials) + return + } + // For apikey type, create directly if 
(!apiKeyValue.value.trim()) { appStore.showError(t('admin.accounts.pleaseEnterApiKey')) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 42211ba7..69e2186b 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -567,6 +567,46 @@
+ +
+
+
+ + +

Service Account JSON 不在编辑页显示;需要更换 JSON 时请删除账号后重新创建。

+
+
+ + +

不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。

+
+
+
+
@@ -1987,6 +2027,55 @@ const editBedrockSessionToken = ref('') const editBedrockRegion = ref('') const editBedrockForceGlobal = ref(false) const editBedrockApiKeyValue = ref('') +const editVertexProjectId = ref('') +const editVertexClientEmail = ref('') +const editVertexLocation = ref('us-central1') +const vertexLocationOptions = [ + { + label: 'Common', + options: [ + { value: 'us-central1', label: 'us-central1 (Iowa)' }, + { value: 'global', label: 'global' }, + { value: 'us', label: 'us' }, + { value: 'eu', label: 'eu' } + ] + }, + { + label: 'United States', + options: [ + { value: 'us-east1', label: 'us-east1 (South Carolina)' }, + { value: 'us-east4', label: 'us-east4 (Northern Virginia)' }, + { value: 'us-east5', label: 'us-east5 (Columbus)' }, + { value: 'us-south1', label: 'us-south1 (Dallas)' }, + { value: 'us-west1', label: 'us-west1 (Oregon)' }, + { value: 'us-west4', label: 'us-west4 (Las Vegas)' } + ] + }, + { + label: 'Europe', + options: [ + { value: 'europe-west1', label: 'europe-west1 (Belgium)' }, + { value: 'europe-west2', label: 'europe-west2 (London)' }, + { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' }, + { value: 'europe-west4', label: 'europe-west4 (Netherlands)' }, + { value: 'europe-west6', label: 'europe-west6 (Zurich)' }, + { value: 'europe-west8', label: 'europe-west8 (Milan)' }, + { value: 'europe-west9', label: 'europe-west9 (Paris)' } + ] + }, + { + label: 'Asia Pacific', + options: [ + { value: 'asia-east1', label: 'asia-east1 (Taiwan)' }, + { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' }, + { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' }, + { value: 'asia-northeast3', label: 'asia-northeast3 (Seoul)' }, + { value: 'asia-south1', label: 'asia-south1 (Mumbai)' }, + { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' }, + { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' } + ] + } +] as const const isBedrockAPIKeyMode = computed(() => props.account?.type === 
'bedrock' && (props.account?.credentials as Record)?.auth_mode === 'apikey' @@ -2246,6 +2335,9 @@ const syncFormFromAccount = (newAccount: Account | null) => { const credentials = newAccount.credentials as Record | undefined interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true + editVertexProjectId.value = '' + editVertexClientEmail.value = '' + editVertexLocation.value = 'us-central1' // Load mixed scheduling setting (only for antigravity accounts) mixedScheduling.value = false @@ -2467,6 +2559,11 @@ const syncFormFromAccount = (newAccount: Account | null) => { } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' + } else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editVertexProjectId.value = (credentials.project_id as string) || '' + editVertexClientEmail.value = (credentials.client_email as string) || '' + editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1' } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -3057,6 +3154,38 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + if (!editVertexProjectId.value.trim()) { + appStore.showError('Service Account JSON 缺少 project_id') + return + } + if (!editVertexClientEmail.value.trim()) { + appStore.showError('Service Account JSON 缺少 
client_email') + return + } + if (!editVertexLocation.value.trim()) { + appStore.showError('请填写 Vertex location') + return + } + + if (!currentCredentials.service_account_json && !currentCredentials.service_account) { + appStore.showError('请上传 Service Account JSON') + return + } + newCredentials.project_id = editVertexProjectId.value.trim() + newCredentials.client_email = editVertexClientEmail.value.trim() + newCredentials.location = editVertexLocation.value.trim() + newCredentials.tier_id = 'vertex' + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else if (props.account.type === 'bedrock') { const currentCredentials = (props.account.credentials as Record) || {} diff --git a/frontend/src/components/common/PlatformTypeBadge.vue b/frontend/src/components/common/PlatformTypeBadge.vue index 1ebc8892..1c7b08c0 100644 --- a/frontend/src/components/common/PlatformTypeBadge.vue +++ b/frontend/src/components/common/PlatformTypeBadge.vue @@ -25,6 +25,7 @@ + {{ typeLabel }} @@ -88,6 +89,8 @@ const typeLabel = computed(() => { return 'Key' case 'bedrock': return 'AWS' + case 'service_account': + return 'Vertex' default: return props.type } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 2a15ad00..80789011 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -641,7 +641,7 @@ export interface UpdateGroupRequest { // ==================== Account & Proxy Types ==================== export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' -export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' +export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'service_account' export type OAuthAddMethod = 'oauth' | 'setup-token' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' From 
9b6dcc57bda7daf87dc4b0552cd0267a2a163782 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 26 Apr 2026 12:31:52 +0800 Subject: [PATCH 05/46] =?UTF-8?q?feat(affiliate):=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E9=82=80=E8=AF=B7=E8=BF=94=E5=88=A9=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复返利不到账的根因:tryClaimAffiliateRebateAudit 中 PostgreSQL 参数类型推断冲突 - 补全 OAuth 注册路径(LinuxDo/OIDC/WeChat/Pending Flow)的邀请码绑定 - 前端 OAuth 注册页面传递 aff_code 参数 - 新增返利冻结期机制:可配置冻结时间,到期后自动解冻(懒解冻) - 新增返利有效期:绑定后 N 天内有效,过期不再产生返利 - 新增单人返利上限:超出上限部分精确截断 - 增强返利流程 slog 结构化日志,便于排查问题 - 已邀请用户列表增加返利明细列 --- .gitignore | 1 + .../internal/handler/admin/setting_handler.go | 48 +++++++ .../internal/handler/auth_linuxdo_oauth.go | 3 +- .../handler/auth_oauth_pending_flow.go | 2 + backend/internal/handler/auth_oidc_oauth.go | 3 +- backend/internal/handler/auth_wechat_oauth.go | 3 +- backend/internal/handler/dto/settings.go | 13 +- backend/internal/repository/affiliate_repo.go | 116 +++++++++++++-- .../affiliate_repo_integration_test.go | 2 +- backend/internal/server/api_contract_test.go | 6 + backend/internal/service/affiliate_service.go | 54 ++++++- .../internal/service/auth_oauth_email_flow.go | 2 + backend/internal/service/auth_service.go | 21 ++- .../service/auth_service_register_test.go | 4 +- backend/internal/service/domain_constants.go | 16 ++- .../internal/service/payment_fulfillment.go | 35 +++-- backend/internal/service/setting_service.go | 84 +++++++++++ backend/internal/service/settings_view.go | 15 +- .../133_affiliate_rebate_freeze.sql | 17 +++ .../api/__tests__/auth-oauth-adoption.spec.ts | 40 ++++++ frontend/src/api/admin/settings.ts | 6 + frontend/src/api/auth.ts | 35 +++-- .../components/auth/LinuxDoOAuthSection.vue | 5 +- .../src/components/auth/OidcOAuthSection.vue | 3 + .../components/auth/WechatOAuthSection.vue | 3 + frontend/src/i18n/locales/en.ts | 12 +- frontend/src/i18n/locales/zh.ts | 12 +- frontend/src/types/index.ts | 2 + 
.../utils/__tests__/oauthAffiliate.spec.ts | 48 +++++++ frontend/src/utils/oauthAffiliate.ts | 133 ++++++++++++++++++ frontend/src/views/admin/SettingsView.vue | 56 ++++++++ frontend/src/views/auth/EmailVerifyView.vue | 9 +- .../src/views/auth/LinuxDoCallbackView.vue | 22 ++- frontend/src/views/auth/LoginView.vue | 3 + frontend/src/views/auth/OidcCallbackView.vue | 22 ++- frontend/src/views/auth/RegisterView.vue | 40 +++++- .../src/views/auth/WechatCallbackView.vue | 22 ++- .../auth/__tests__/EmailVerifyView.spec.ts | 3 + .../__tests__/LinuxDoCallbackView.spec.ts | 1 + .../auth/__tests__/OidcCallbackView.spec.ts | 1 + .../auth/__tests__/WechatCallbackView.spec.ts | 1 + frontend/src/views/user/AffiliateView.vue | 32 +++-- 42 files changed, 852 insertions(+), 104 deletions(-) create mode 100644 backend/migrations/133_affiliate_rebate_freeze.sql create mode 100644 frontend/src/utils/__tests__/oauthAffiliate.spec.ts create mode 100644 frontend/src/utils/oauthAffiliate.ts diff --git a/.gitignore b/.gitignore index bf7ee064..a61f406d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ docs/claude-relay-service/ +.codex # =================== # Go 后端 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 40bf1c69..320dbd6b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, AffiliateRebateRate: settings.AffiliateRebateRate, + AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours, + AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap, DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, @@ 
-342,6 +345,9 @@ type UpdateSettingsRequest struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` + AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"` + AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"` + AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` @@ -485,6 +491,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if affiliateRebateRate > service.AffiliateRebateRateMax { affiliateRebateRate = service.AffiliateRebateRateMax } + affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours + if req.AffiliateRebateFreezeHours != nil { + affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours + } + if affiliateRebateFreezeHours < 0 { + affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault + } + if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax { + affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax + } + affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays + if req.AffiliateRebateDurationDays != nil { + affiliateRebateDurationDays = *req.AffiliateRebateDurationDays + } + if affiliateRebateDurationDays < 0 { + affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault + } + if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax { + affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax + } + affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap + if req.AffiliateRebatePerInviteeCap != nil { + affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap + } + if 
affiliateRebatePerInviteeCap < 0 { + affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault + } // 通用表格配置:兼容旧客户端未传字段时保留当前值。 if req.TableDefaultPageSize <= 0 { req.TableDefaultPageSize = previousSettings.TableDefaultPageSize @@ -1137,6 +1170,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, AffiliateRebateRate: affiliateRebateRate, + AffiliateRebateFreezeHours: affiliateRebateFreezeHours, + AffiliateRebateDurationDays: affiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap, DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, @@ -1458,6 +1494,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, AffiliateRebateRate: updatedSettings.AffiliateRebateRate, + AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours, + AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap, DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -1768,6 +1807,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AffiliateRebateRate != after.AffiliateRebateRate { changed = append(changed, "affiliate_rebate_rate") } + if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours { + changed = append(changed, "affiliate_rebate_freeze_hours") + } + if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays { + changed = append(changed, "affiliate_rebate_duration_days") + } + if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap { + changed = 
append(changed, "affiliate_rebate_per_invitee_cap") + } if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { changed = append(changed, "default_subscriptions") } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2ef05963..7df4abfd 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( type completeLinuxDoOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 604ad903..490afd0f 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct { VerifyCode string `json:"verify_code,omitempty"` Password string `json:"password" binding:"required,min=6"` InvitationCode string `json:"invitation_code,omitempty"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider 
string) user, strings.TrimSpace(req.InvitationCode), strings.TrimSpace(session.ProviderType), + strings.TrimSpace(req.AffCode), ); err != nil { _ = tx.Rollback() if rollbackCreatedUser(err) { diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0ac8871b..4264002d 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( type completeOIDCOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index efee4cc0..34e70ed0 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService type completeWeChatOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := 
h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 051fab18..92ae4dc6 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -106,11 +106,14 @@ type SystemSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` - DefaultUserRPMLimit int `json:"default_user_rpm_limit"` - DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` + AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"` + AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"` + AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` + DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index e3dd56b8..ef89e5b6 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID return bound, nil } -func (r 
*affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) { +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { if amount <= 0 { return false, nil } var applied bool err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { - res, err := txClient.ExecContext(txCtx, - "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2", - amount, inviterID, - ) + // freezeHours > 0: add to frozen quota; == 0: add to available quota directly + var updateSQL string + if freezeHours > 0 { + updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2" + } else { + updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2" + } + res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID) if err != nil { return err } @@ -106,10 +110,19 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite return nil } - if _, err = txClient.ExecContext(txCtx, ` + if freezeHours > 0 { + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, + inviterID, amount, inviteeUserID, freezeHours); err != nil { + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } + } else { + if _, err = txClient.ExecContext(txCtx, ` INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { - return 
fmt.Errorf("insert affiliate accrue ledger: %w", err) + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } } applied = true @@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); return applied, nil } +func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) { + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, + `SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`, + inviterID, inviteeUserID) + if err != nil { + return 0, fmt.Errorf("query accrued rebate from invitee: %w", err) + } + defer func() { _ = rows.Close() }() + var total float64 + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, rows.Close() +} + +func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) { + var thawed float64 + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + var err error + thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID) + return err + }) + return thawed, err +} + +// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx. 
+func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) { + rows, err := txClient.QueryContext(txCtx, ` +WITH matured AS ( + UPDATE user_affiliate_ledger + SET frozen_until = NULL, updated_at = NOW() + WHERE user_id = $1 + AND frozen_until IS NOT NULL + AND frozen_until <= NOW() + RETURNING amount +) +SELECT COALESCE(SUM(amount), 0) FROM matured`, userID) + if err != nil { + return 0, fmt.Errorf("thaw frozen quota: %w", err) + } + defer func() { _ = rows.Close() }() + + var thawed float64 + if rows.Next() { + if err := rows.Scan(&thawed); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if thawed <= 0 { + return 0, nil + } + + _, err = txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_quota = aff_quota + $1, + aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0), + updated_at = NOW() +WHERE user_id = $2`, thawed, userID) + if err != nil { + return 0, fmt.Errorf("move thawed quota: %w", err) + } + return thawed, nil +} + func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) { var transferred float64 var newBalance float64 @@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID return err } + // Thaw any matured frozen quota before transfer. 
+ if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil { + return fmt.Errorf("thaw before transfer: %w", err) + } + rows, err := txClient.QueryContext(txCtx, ` WITH claimed AS ( SELECT aff_quota::double precision AS amount @@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, SELECT ua.user_id, COALESCE(u.email, ''), COALESCE(u.username, ''), - ua.created_at + ua.created_at, + COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate FROM user_affiliates ua LEFT JOIN users u ON u.id = ua.user_id +LEFT JOIN user_affiliate_ledger ual + ON ual.user_id = $1 + AND ual.source_user_id = ua.user_id + AND ual.action = 'accrue' WHERE ua.inviter_id = $1 +GROUP BY ua.user_id, u.email, u.username, ua.created_at ORDER BY ua.created_at DESC LIMIT $2`, inviterID, limit) if err != nil { @@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit) for rows.Next() { var item service.AffiliateInvitee var createdAt time.Time - if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil { + if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil { return nil, err } item.CreatedAt = &createdAt @@ -299,6 +393,7 @@ SELECT user_id, inviter_id, aff_count, aff_quota::double precision, + aff_frozen_quota::double precision, aff_history_quota::double precision, created_at, updated_at @@ -326,6 +421,7 @@ WHERE user_id = $1`, userID) &inviterID, &out.AffCount, &out.AffQuota, + &out.AffFrozenQuota, &out.AffHistoryQuota, &out.CreatedAt, &out.UpdatedAt, @@ -351,6 +447,7 @@ SELECT user_id, inviter_id, aff_count, aff_quota::double precision, + aff_frozen_quota::double precision, aff_history_quota::double precision, created_at, updated_at @@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code))) &inviterID, &out.AffCount, &out.AffQuota, + &out.AffFrozenQuota, &out.AffHistoryQuota, &out.CreatedAt, &out.UpdatedAt, diff --git 
a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go index 369f57cf..697a193b 100644 --- a/backend/internal/repository/affiliate_repo_integration_test.go +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { require.NoError(t, err) require.True(t, bound, "invitee must bind to inviter") - applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5) + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) require.NoError(t, err) require.True(t, applied, "AccrueQuota must report applied=true") diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 39286cbf..ca6fd0cc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) { "default_concurrency": 5, "default_balance": 1.25, "affiliate_rebate_rate": 20, + "affiliate_rebate_freeze_hours": 0, + "affiliate_rebate_duration_days": 0, + "affiliate_rebate_per_invitee_cap": 0, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -898,6 +901,9 @@ func TestAPIContracts(t *testing.T) { "default_concurrency": 0, "default_balance": 0, "affiliate_rebate_rate": 20, + "affiliate_rebate_freeze_hours": 0, + "affiliate_rebate_duration_days": 0, + "affiliate_rebate_per_invitee_cap": 0, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index aca32076..5a4e91e7 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -65,16 +65,18 @@ type AffiliateSummary struct { InviterID *int64 `json:"inviter_id,omitempty"` AffCount int 
`json:"aff_count"` AffQuota float64 `json:"aff_quota"` + AffFrozenQuota float64 `json:"aff_frozen_quota"` AffHistoryQuota float64 `json:"aff_history_quota"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } type AffiliateInvitee struct { - UserID int64 `json:"user_id"` - Email string `json:"email"` - Username string `json:"username"` - CreatedAt *time.Time `json:"created_at,omitempty"` + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + CreatedAt *time.Time `json:"created_at,omitempty"` + TotalRebate float64 `json:"total_rebate"` } type AffiliateDetail struct { @@ -83,6 +85,7 @@ type AffiliateDetail struct { InviterID *int64 `json:"inviter_id,omitempty"` AffCount int `json:"aff_count"` AffQuota float64 `json:"aff_quota"` + AffFrozenQuota float64 `json:"aff_frozen_quota"` AffHistoryQuota float64 `json:"aff_history_quota"` // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例: // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。 @@ -95,7 +98,9 @@ type AffiliateRepository interface { EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) - AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) + AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) + GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) + ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error) @@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64 } func (s *AffiliateService) 
GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) { + // Lazy thaw: move any matured frozen quota to available before reading. + if s != nil && s.repo != nil { + // best-effort: thaw failure is non-fatal + _, _ = s.repo.ThawFrozenQuota(ctx, userID) + } + summary, err := s.EnsureUserAffiliate(ctx, userID) if err != nil { return nil, err @@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) InviterID: summary.InviterID, AffCount: summary.AffCount, AffQuota: summary.AffQuota, + AffFrozenQuota: summary.AffFrozenQuota, AffHistoryQuota: summary.AffHistoryQuota, EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary), Invitees: invitees, @@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID if err != nil { return 0, err } + // 有效期检查:超过返利有效期后不再产生返利 + if s.settingService != nil { + if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 { + if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) { + return 0, nil + } + } + } + rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary) rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8) if rebate <= 0 { return 0, nil } - applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate) + // 单人上限检查:精确截断到剩余额度 + if s.settingService != nil { + if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 { + existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID) + if err != nil { + return 0, err + } + if existing >= perInviteeCap { + return 0, nil + } + if remaining := perInviteeCap - existing; rebate > remaining { + rebate = roundTo(remaining, 8) + } + } + } + + var freezeHours int + if s.settingService != nil { + freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) + } + + applied, err := 
s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) if err != nil { return 0, err } diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index a18cf39c..9815f31b 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( user *User, invitationCode string, signupSource string, + affiliateCode string, ) error { if s == nil || user == nil || user.ID <= 0 { return ErrServiceUnavailable @@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( s.updateOAuthSignupSource(ctx, user.ID, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) return nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 08b0f4b7..b1adf071 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 // invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { +// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -666,6 +667,7 @@ func (s *AuthService) 
LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource } } +// bindOAuthAffiliate initializes the affiliate profile and binds the inviter +// for an OAuth-registered user. Failures are logged but never block registration. 
+func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) { + if s.affiliateService == nil || userID <= 0 { + return + } + if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err) + } + if code := strings.TrimSpace(affiliateCode); code != "" { + if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err) + } + } +} + func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { if user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index c1ad6240..acc44a38 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "") require.NoError(t, err) require.NotNil(t, tokenPair) require.NotNil(t, user) @@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + tokenPair, user, err := 
service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "") require.NoError(t, err) require.NotNil(t, tokenPair) require.Equal(t, existing.ID, user.ID) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 04037987..0ef4a486 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -20,10 +20,15 @@ const ( // Affiliate rebate settings const ( - AffiliateRebateRateDefault = 20.0 - AffiliateRebateRateMin = 0.0 - AffiliateRebateRateMax = 100.0 - AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 + AffiliateRebateRateDefault = 20.0 + AffiliateRebateRateMin = 0.0 + AffiliateRebateRateMax = 100.0 + AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 + AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容) + AffiliateRebateFreezeHoursMax = 720 // 最大 30 天 + AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效 + AffiliateRebateDurationDaysMax = 3650 // ~10 年 + AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限 ) // Platform constants @@ -97,6 +102,9 @@ const ( SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关 SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100) + SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结) + SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久) + SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index c6167447..5df69aea 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -269,7 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o 
*dbent.PaymentOrder) e switch action { case redeemActionSkipCompleted: - s.applyAffiliateRebateForOrder(ctx, o) + if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { + return err + } // Code already created and redeemed — just mark completed return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") case redeemActionCreate: @@ -283,7 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } - s.applyAffiliateRebateForOrder(ctx, o) + if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { + return err + } return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") } @@ -361,12 +365,12 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action return c > 0 } -func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) { +func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error { if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 { - return + return nil } if s.affiliateService == nil { - return + return nil } tx, err := s.entClient.Tx(ctx) @@ -374,7 +378,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("begin affiliate rebate tx: %v", err), }) - return + return fmt.Errorf("begin affiliate rebate tx: %w", err) } defer func() { _ = tx.Rollback() }() @@ -384,10 +388,10 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("claim affiliate rebate audit: %w", err) } if !claimed { - return + return nil } rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) @@ 
-395,7 +399,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("accrue affiliate rebate: %w", err) } if rebateAmount <= 0 { @@ -406,14 +410,15 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("update affiliate rebate skipped audit: %w", err) } if err := tx.Commit(); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), }) + return fmt.Errorf("commit affiliate rebate tx: %w", err) } - return + return nil } if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{ @@ -423,14 +428,16 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("update affiliate rebate applied audit: %w", err) } if err := tx.Commit(); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), }) + return fmt.Errorf("commit affiliate rebate tx: %w", err) } + return nil } func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) { @@ -444,11 +451,11 @@ func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, clien }) rows, err := client.QueryContext(ctx, ` INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) -SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW() +SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 
'system', NOW() WHERE NOT EXISTS ( SELECT 1 FROM payment_audit_logs - WHERE order_id = $1 + WHERE order_id = $1::text AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') ) ON CONFLICT (order_id, action) DO NOTHING diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f871ee85..33316031 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -1175,6 +1175,24 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) + if settings.AffiliateRebateFreezeHours < 0 { + settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault + } + if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax { + settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax + } + updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours) + if settings.AffiliateRebateDurationDays < 0 { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault + } + if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax + } + updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays) + if settings.AffiliateRebatePerInviteeCap < 0 { + settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault + } + updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64) updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := 
json.Marshal(settings.DefaultSubscriptions) if err != nil { @@ -1512,6 +1530,54 @@ func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) floa return clampAffiliateRebateRate(rate) } +// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。 +// 返回 0 表示不冻结(向后兼容)。 +func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours) + if err != nil { + return AffiliateRebateFreezeHoursDefault + } + hours, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || hours < 0 { + return AffiliateRebateFreezeHoursDefault + } + if hours > AffiliateRebateFreezeHoursMax { + return AffiliateRebateFreezeHoursMax + } + return hours +} + +// GetAffiliateRebateDurationDays 返回返利有效期(天)。 +// 返回 0 表示永久有效。 +func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays) + if err != nil { + return AffiliateRebateDurationDaysDefault + } + days, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || days < 0 { + return AffiliateRebateDurationDaysDefault + } + if days > AffiliateRebateDurationDaysMax { + return AffiliateRebateDurationDaysMax + } + return days +} + +// GetAffiliateRebatePerInviteeCap 返回单人返利上限。 +// 返回 0 表示无上限。 +func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap) + if err != nil { + return AffiliateRebatePerInviteeCapDefault + } + cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64) + if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) { + return AffiliateRebatePerInviteeCapDefault + } + return cap +} + // IsPasswordResetEnabled 检查是否启用密码重置功能 // 要求:必须同时开启邮件验证 func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { @@ -1755,6 +1821,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) 
error { SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), + SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault), + SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault), + SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64), SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", @@ -1890,6 +1959,21 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.AffiliateRebateRate = AffiliateRebateRateDefault } + if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 { + if freezeHours > AffiliateRebateFreezeHoursMax { + freezeHours = AffiliateRebateFreezeHoursMax + } + result.AffiliateRebateFreezeHours = freezeHours + } + if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 { + if durationDays > AffiliateRebateDurationDaysMax { + durationDays = AffiliateRebateDurationDaysMax + } + result.AffiliateRebateDurationDays = durationDays + } + if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 { + result.AffiliateRebatePerInviteeCap = perInviteeCap + } result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 70d8efc3..5ec7d313 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -104,12 +104,15 @@ type SystemSettings struct { 
CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints - DefaultConcurrency int - DefaultBalance float64 - AffiliateEnabled bool - AffiliateRebateRate float64 - DefaultUserRPMLimit int - DefaultSubscriptions []DefaultSubscriptionSetting + DefaultConcurrency int + DefaultBalance float64 + AffiliateEnabled bool + AffiliateRebateRate float64 + AffiliateRebateFreezeHours int + AffiliateRebateDurationDays int + AffiliateRebatePerInviteeCap float64 + DefaultUserRPMLimit int + DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` diff --git a/backend/migrations/133_affiliate_rebate_freeze.sql b/backend/migrations/133_affiliate_rebate_freeze.sql new file mode 100644 index 00000000..b87d59b7 --- /dev/null +++ b/backend/migrations/133_affiliate_rebate_freeze.sql @@ -0,0 +1,17 @@ +-- 1) Add frozen quota column to user_affiliates for rebate freeze period. +ALTER TABLE user_affiliates + ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0; + +COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)'; + +-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking. +-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp. +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL; + +COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen'; + +-- 3) Partial index for efficient thaw queries (only rows still frozen). 
+CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw + ON user_affiliate_ledger (user_id, frozen_until) + WHERE frozen_until IS NOT NULL; diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index a484d7ed..07a68c03 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => { }) }) + it('posts affiliate code when completing linuxdo oauth registration', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration( + 'invite-code', + { + adoptDisplayName: true, + adoptAvatar: false + }, + ' AFF123 ' + ) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + aff_code: 'AFF123', + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts oidc invitation completion with adoption decisions', async () => { const { completeOIDCOAuthRegistration } = await import('@/api/auth') @@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => { }) }) + it('posts affiliate code when creating pending wechat oauth account', async () => { + const { createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount( + 'invite-code', + { + adoptDisplayName: false, + adoptAvatar: true + }, + 'WXAFF' + ) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + aff_code: 'WXAFF', + adopt_display_name: false, + adopt_avatar: true + }) + }) + it('classifies oauth completion results as login or bind', async () => { const { getOAuthCompletionKind } = await import('@/api/auth') diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 0d98c9e9..defbab43 100644 --- a/frontend/src/api/admin/settings.ts +++ 
b/frontend/src/api/admin/settings.ts @@ -309,6 +309,9 @@ export interface SystemSettings { // Default settings default_balance: number; affiliate_rebate_rate: number; + affiliate_rebate_freeze_hours: number; + affiliate_rebate_duration_days: number; + affiliate_rebate_per_invitee_cap: number; default_concurrency: number; default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; @@ -494,6 +497,9 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; affiliate_rebate_rate?: number; + affiliate_rebate_freeze_hours?: number; + affiliate_rebate_duration_days?: number; + affiliate_rebate_per_invitee_cap?: number; default_concurrency?: number; default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index f49f3a1f..bb990fc4 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -564,9 +564,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { - return createPendingLinuxDoOAuthAccount(invitationCode, decision) + return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode) } /** @@ -576,27 +577,32 @@ export async function completeLinuxDoOAuthRegistration( */ export async function completeOIDCOAuthRegistration( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOIDCOAuthAccount(invitationCode, decision) + return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode) } export async function completeWeChatOAuthRegistration( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingWeChatOAuthAccount(invitationCode, decision) + return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode) } async function 
createPendingOAuthAccount( provider: 'linuxdo' | 'oidc' | 'wechat', invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { + const normalizedAffiliateCode = affiliateCode?.trim() const { data } = await apiClient.post( `/auth/oauth/${provider}/complete-registration`, { invitation_code: invitationCode, + ...(normalizedAffiliateCode ? { aff_code: normalizedAffiliateCode } : {}), ...serializeOAuthAdoptionDecision(decision) } ) @@ -605,23 +611,26 @@ async function createPendingOAuthAccount( export async function createPendingLinuxDoOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('linuxdo', invitationCode, decision) + return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode) } export async function createPendingOIDCOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('oidc', invitationCode, decision) + return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode) } export async function createPendingWeChatOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('wechat', invitationCode, decision) + return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode) } export async function completePendingOAuthBindLogin( diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue index c740d06f..6b245123 100644 --- a/frontend/src/components/auth/LinuxDoOAuthSection.vue +++ b/frontend/src/components/auth/LinuxDoOAuthSection.vue @@ -42,9 +42,11 @@ diff --git 
a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index bc4c6215..2f061118 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -141,7 +141,17 @@
+ +
+
+

+ {{ t("admin.settings.openaiFastPolicy.title") }} +

+

+ {{ t("admin.settings.openaiFastPolicy.description") }} +

+
+
+ +
+ {{ t("admin.settings.openaiFastPolicy.empty") }} +
+ + +
+
+ + {{ + t("admin.settings.openaiFastPolicy.ruleHeader", { + index: ruleIndex + 1, + }) + }} + + +
+ +
+ +
+ + +
+ + +
+ + +

+ {{ t("admin.settings.openaiFastPolicy.errorMessageHint") }} +

+
+ + +
+ +

+ {{ + t("admin.settings.openaiFastPolicy.modelWhitelistHint") + }} +

+
+ + +
+ +
+ + +
+ + +
+
+
+ + +
+ +

+ {{ t("admin.settings.openaiFastPolicy.saveHint") }} +

+
+
+
@@ -5199,6 +5478,7 @@ import type { SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, + OpenAIFastPolicyRule, WeChatConnectMode, WebSearchEmulationConfig, WebSearchProviderConfig, @@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({ }>, }); +// OpenAI Fast/Flex Policy 状态 +const openaiFastPolicyForm = reactive({ + rules: [] as OpenAIFastPolicyRule[], +}); +// 标记 openai_fast_policy_settings 是否已成功从后端加载, +// 避免后端 GET 出错或字段缺失时,保存把默认规则覆盖成空数组。 +const openaiFastPolicyLoaded = ref(false); + const tablePageSizeMin = 5; const tablePageSizeMax = 1000; const tablePageSizeDefault = 20; @@ -6116,6 +6404,23 @@ async function loadSettings() { ); form.oidc_connect_client_secret = ""; + // Load OpenAI fast/flex policy rules from bulk settings. + // 仅当 payload 真的包含该字段时填充并标记为已加载;否则保持表单空值, + // 让 saveSettings 在未加载时跳过该字段,防止覆盖后端默认规则。 + if ( + settings.openai_fast_policy_settings && + Array.isArray(settings.openai_fast_policy_settings.rules) + ) { + openaiFastPolicyForm.rules = + settings.openai_fast_policy_settings.rules.map((rule) => ({ + ...rule, + model_whitelist: rule.model_whitelist + ? [...rule.model_whitelist] + : [], + })); + openaiFastPolicyLoaded.value = true; + } + // Load web search emulation config separately await loadWebSearchConfig(); } catch (error: unknown) { @@ -6460,10 +6765,39 @@ async function saveSettings() { affiliate_enabled: form.affiliate_enabled, }; + // 仅当 openai_fast_policy_settings 已成功从后端加载时才回写, + // 否则省略整个字段,让后端保留既有规则(含默认值)。 + if (openaiFastPolicyLoaded.value) { + payload.openai_fast_policy_settings = { + rules: openaiFastPolicyForm.rules.map((rule) => { + const whitelist = (rule.model_whitelist || []) + .map((p) => p.trim()) + .filter((p) => p !== ""); + const hasWhitelist = whitelist.length > 0; + return { + service_tier: rule.service_tier, + action: rule.action, + scope: rule.scope, + error_message: + rule.action === "block" ? rule.error_message : undefined, + model_whitelist: hasWhitelist ? 
whitelist : undefined, + fallback_action: hasWhitelist + ? rule.fallback_action || "pass" + : undefined, + fallback_error_message: + hasWhitelist && rule.fallback_action === "block" + ? rule.fallback_error_message + : undefined, + }; + }), + }; + } + appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults); const updated = await adminAPI.settings.updateSettings(payload); for (const [key, value] of Object.entries(updated)) { + if (key === "openai_fast_policy_settings") continue; if (value !== null && value !== undefined) { (form as Record)[key] = value; } @@ -6507,6 +6841,20 @@ async function saveSettings() { form.wechat_connect_mode, ); form.oidc_connect_client_secret = ""; + // Refresh OpenAI fast/flex policy from server response + if ( + updated.openai_fast_policy_settings && + Array.isArray(updated.openai_fast_policy_settings.rules) + ) { + openaiFastPolicyForm.rules = + updated.openai_fast_policy_settings.rules.map((rule) => ({ + ...rule, + model_whitelist: rule.model_whitelist + ? 
[...rule.model_whitelist] + : [], + })); + openaiFastPolicyLoaded.value = true; + } // Save web search emulation config separately (errors handled internally) const wsOk = await saveWebSearchConfig(); // Refresh cached settings so sidebar/header update immediately @@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() { } } +// ==================== OpenAI Fast/Flex Policy ==================== + +const openaiFastPolicyTierOptions = computed(() => [ + { value: "all", label: t("admin.settings.openaiFastPolicy.tierAll") }, + { + value: "priority", + label: t("admin.settings.openaiFastPolicy.tierPriority"), + }, + { value: "flex", label: t("admin.settings.openaiFastPolicy.tierFlex") }, +]); + +const openaiFastPolicyActionOptions = computed(() => [ + { value: "pass", label: t("admin.settings.openaiFastPolicy.actionPass") }, + { value: "filter", label: t("admin.settings.openaiFastPolicy.actionFilter") }, + { value: "block", label: t("admin.settings.openaiFastPolicy.actionBlock") }, +]); + +const openaiFastPolicyScopeOptions = computed(() => [ + { value: "all", label: t("admin.settings.openaiFastPolicy.scopeAll") }, + { value: "oauth", label: t("admin.settings.openaiFastPolicy.scopeOAuth") }, + { value: "apikey", label: t("admin.settings.openaiFastPolicy.scopeAPIKey") }, + { + value: "bedrock", + label: t("admin.settings.openaiFastPolicy.scopeBedrock"), + }, +]); + +function addOpenAIFastPolicyRule() { + openaiFastPolicyForm.rules.push({ + service_tier: "priority", + action: "filter", + scope: "all", + error_message: "", + model_whitelist: [], + fallback_action: "pass", + fallback_error_message: "", + }); +} + +function removeOpenAIFastPolicyRule(index: number) { + openaiFastPolicyForm.rules.splice(index, 1); +} + +function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) { + if (!rule.model_whitelist) rule.model_whitelist = []; + rule.model_whitelist.push(""); +} + +function removeOpenAIFastPolicyModelPattern( + rule: OpenAIFastPolicyRule, + idx: 
number, +) { + rule.model_whitelist?.splice(idx, 1); +} + async function saveBetaPolicySettings() { betaPolicySaving.value = true; try { From 04b2866f65f31c044a991e9f2c1b299927a2ac1b Mon Sep 17 00:00:00 2001 From: ivanvolt Date: Tue, 28 Apr 2026 16:26:09 +0800 Subject: [PATCH 26/46] fix: use Responses-compatible function tool_choice format --- .../pkg/apicompat/anthropic_responses_test.go | 37 ++++++++++- .../pkg/apicompat/anthropic_to_responses.go | 6 +- .../chatcompletions_responses_test.go | 2 + .../apicompat/chatcompletions_to_responses.go | 6 +- .../responses_to_anthropic_request.go | 15 ++++- .../service/openai_codex_transform.go | 62 +++++++++++++++++-- .../service/openai_codex_transform_test.go | 38 ++++++++++++ 7 files changed, 150 insertions(+), 16 deletions(-) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index facfe572..e8b25c2b 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -991,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) { var tc map[string]any require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) assert.Equal(t, "function", tc["type"]) - fn, ok := tc["function"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "get_weather", fn["name"]) + assert.Equal(t, "get_weather", tc["name"]) + assert.NotContains(t, tc, "function") +} + +func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-5.2", + Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`), + ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`), + } + + resp, err := ResponsesToAnthropicRequest(req) + require.NoError(t, err) + + var tc map[string]string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "tool", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) 
+} + +func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-5.2", + Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`), + ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`), + } + + resp, err := ResponsesToAnthropicRequest(req) + require.NoError(t, err) + + var tc map[string]string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "tool", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) } // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index 485262e8..268f9f22 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { // {"type":"auto"} → "auto" // {"type":"any"} → "required" // {"type":"none"} → "none" -// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}} +// {"type":"tool","name":"X"} → {"type":"function","name":"X"} func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) { var tc struct { Type string `json:"type"` @@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage return json.Marshal("none") case "tool": return json.Marshal(map[string]any{ - "type": "function", - "function": map[string]string{"name": tc.Name}, + "type": "function", + "name": tc.Name, }) default: // Pass through unknown types as-is diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index c140449a..35d42999 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ 
b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { var tc map[string]any require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) assert.Equal(t, "function", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) + assert.NotContains(t, tc, "function") } func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index c2725406..64ef5781 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R // // "auto" → "auto" // "none" → "none" -// {"name":"X"} → {"type":"function","function":{"name":"X"}} +// {"name":"X"} → {"type":"function","name":"X"} func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { // Try string first ("auto", "none", etc.) — pass through as-is. 
var s string @@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, return nil, err } return json.Marshal(map[string]any{ - "type": "function", - "function": map[string]string{"name": obj.Name}, + "type": "function", + "name": obj.Name, }) } diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go index 49426b88..8fa652f2 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go @@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { // "auto" → {"type":"auto"} // "required" → {"type":"any"} // "none" → {"type":"none"} -// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} +// {"type":"function","name":"X"} → {"type":"tool","name":"X"} +// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) { // Try as string first var s string @@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage // Try as object with type=function var tc struct { Type string `json:"type"` + Name string `json:"name"` Function struct { Name string `json:"name"` } `json:"function"` } - if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { + if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" { + name := strings.TrimSpace(tc.Name) + if name == "" { + name = strings.TrimSpace(tc.Function.Name) + } + if name == "" { + return raw, nil + } return json.Marshal(map[string]string{ "type": "tool", - "name": tc.Function.Name, + "name": name, }) } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index e765d7e9..0fda16b0 100644 --- 
a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -141,9 +141,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { reqBody["tool_choice"] = map[string]any{ "type": "function", - "function": map[string]any{ - "name": name, - }, + "name": name, } } } @@ -219,9 +217,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool { return false } choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"])) - if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) { + if choiceType == "" { return false } + modified := false + if choiceType == "function" { + name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) + if name == "" { + if function, ok := choiceMap["function"].(map[string]any); ok { + name = strings.TrimSpace(firstNonEmptyString(function["name"])) + } + } + if name == "" { + reqBody["tool_choice"] = "auto" + return true + } + if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name { + choiceMap["name"] = name + modified = true + } + if _, ok := choiceMap["function"]; ok { + delete(choiceMap, "function") + modified = true + } + if !codexToolsContainFunctionName(reqBody["tools"], name) { + reqBody["tool_choice"] = "auto" + return true + } + return modified + } + if codexToolsContainType(reqBody["tools"], choiceType) { + return modified + } reqBody["tool_choice"] = "auto" return true } @@ -243,6 +270,33 @@ func codexToolsContainType(rawTools any, toolType string) bool { return false } +func codexToolsContainFunctionName(rawTools any, name string) bool { + tools, ok := rawTools.([]any) + if !ok || strings.TrimSpace(name) == "" { + return false + } + normalizedName := strings.TrimSpace(name) + for _, rawTool := range tools { + tool, ok := rawTool.(map[string]any) + if !ok { + continue + } + if 
strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" { + continue + } + toolName := strings.TrimSpace(firstNonEmptyString(tool["name"])) + if toolName == "" { + if function, ok := tool["function"].(map[string]any); ok { + toolName = strings.TrimSpace(firstNonEmptyString(function["name"])) + } + } + if toolName == normalizedName { + return true + } + } + return false +} + func normalizeCodexToolRoleMessages(input []any) ([]any, bool) { if len(input) == 0 { return input, false diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 75f5c55c..8d9f8574 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -249,6 +249,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) { require.Equal(t, "custom", choice["type"]) } +func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{"name": "shell"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "function", choice["type"]) + require.Equal(t, "shell", choice["name"]) + require.NotContains(t, choice, "function") +} + +func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{"name": "missing"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + require.Equal(t, "auto", reqBody["tool_choice"]) +} + func 
TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.4", From 6327573534d903efc1d63d0e3af92d683d3d293e Mon Sep 17 00:00:00 2001 From: alfadb Date: Tue, 28 Apr 2026 19:12:48 +0800 Subject: [PATCH 27/46] fix(gateway): wrap Anthropic stream EOF as failover error before client output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Anthropic streaming path (gateway_service.go) returned a plain error on upstream SSE read failure, so the handler-level UpstreamFailoverError check never fired and the client received a bare `stream_read_error` event, breaking long-running tasks even when no bytes had been written yet. The most common trigger is HTTP/2 GOAWAY from api.anthropic.com edge backends doing graceful rotation: Go's http.Transport surfaces this as `unexpected EOF` and never auto-retries. Mirror what the OpenAI and antigravity gateways already do: when the read error happens before any byte has reached the client (`!c.Writer.Written()`), return `*UpstreamFailoverError{StatusCode: 502, RetryableOnSameAccount: true}` so the handler can retry on the same or another account. After client output has begun, SSE has no resume protocol — keep the existing passthrough behavior. Tests cover both branches via streamReadCloser-based fixtures. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/internal/service/gateway_service.go | 14 ++++ .../service/gateway_streaming_test.go | 70 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6be19ba6..911bc6fc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7041,6 +7041,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http sendErrorEvent("response_too_large") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } + // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY): + // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。 + // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。 + if !c.Writer.Written() { + logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err) + body, _ := json.Marshal(map[string]string{ + "error": fmt.Sprintf("upstream stream disconnected: %s", ev.err), + }) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: body, + RetryableOnSameAccount: true, + } + } sendErrorEvent("stream_read_error") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index b1584827..389831fa 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -4,9 +4,11 @@ package service import ( "context" + "errors" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -218,3 +220,71 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { body := rec.Body.String() require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") } + +// 
上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前: +// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。 +func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{err: io.ErrUnexpectedEOF}, + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + + require.Error(t, err) + require.Nil(t, result, "失败移交场景下不应返回 streamingResult") + + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试") + require.Contains(t, string(failoverErr.ResponseBody), "upstream stream disconnected") + + // 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定 + require.NotContains(t, rec.Body.String(), "stream_read_error") +} + +// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误: +// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。 +func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + payload: []byte("data: 
{\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"), + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error") + require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult") + + // 不应被错误地包成 failover error + var failoverErr *UpstreamFailoverError + require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover") + + // 客户端必须收到 stream_read_error 事件 + body := rec.Body.String() + require.True(t, + strings.Contains(body, "stream_read_error"), + "已开始流后必须发送 stream_read_error 事件给客户端,实际响应: %q", body) +} From 4c474616b994665a104c5eeb1ccd8c5e96a31ddf Mon Sep 17 00:00:00 2001 From: alfadb Date: Tue, 28 Apr 2026 20:24:17 +0800 Subject: [PATCH 28/46] fix(gateway): emit Anthropic-standard SSE error events and failover body MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-ups to PR #2066's failover-wrap fix: 1. Failover ResponseBody (`UpstreamFailoverError.ResponseBody`) was encoded as `{"error": ""}` (string field). `ExtractUpstreamErrorMessage` probes for `error.message`, `detail`, or top-level `message` only — so `handleFailoverExhausted` and downstream passthrough rules saw an empty message, losing the EOF root cause in ops logs. Re-encode as the Anthropic standard shape `{"type":"error","error":{"type":"upstream_disconnected","message":"..."}}`. (Addresses the inline review comment from copilot-pull-request-reviewer on Wei-Shaw/sub2api#2066.) 2. The streaming `event: error` SSE frame for `response_too_large`, `stream_read_error`, and `stream_timeout` was non-standard (`{"error":""}`). Anthropic SDKs (and Claude Code) expect `{"type":"error","error":{"type":"...","message":"..."}}` and parse `error.type`/`error.message` accordingly. 
Refactor `sendErrorEvent` to take both reason and message, and emit the standard frame so client SDKs surface a real diagnostic message instead of a generic stream error. This does not by itself prevent task interruption on long-stream EOF (SSE has no resume; client-side retry remains the only complete fix), but it gives both server-side ops logs and client-side error UIs a meaningful upstream message so users know the next step is to retry. Tests updated to assert the new body shape on both branches plus a new assertion that `ExtractUpstreamErrorMessage` returns a non-empty string. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/internal/service/gateway_service.go | 38 +++++++++++++++---- .../service/gateway_streaming_test.go | 21 +++++++--- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 911bc6fc..4c4a9b82 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -6871,14 +6871,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } lastDataAt := time.Now() - // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。 + // 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":<reason>,"message":<message>}} + // 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案, + // 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。 errorEventSent := false - sendErrorEvent := func(reason string) { + sendErrorEvent := func(reason, message string) { if errorEventSent { return } errorEventSent = true - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + if message == "" { + message = reason + } + body, err := json.Marshal(map[string]any{ + "type": "error", + "error": map[string]string{ + "type": reason, + "message": message, + }, + }) + if err != nil { + // json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback + body = 
[]byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message)) + } + _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body) flusher.Flush() } @@ -7038,16 +7055,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") + sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize)) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY): // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。 // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。 + disconnectMsg := fmt.Sprintf("upstream stream disconnected: %s", ev.err) if !c.Writer.Written() { logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err) - body, _ := json.Marshal(map[string]string{ - "error": fmt.Sprintf("upstream stream disconnected: %s", ev.err), + body, _ := json.Marshal(map[string]any{ + "type": "error", + "error": map[string]string{ + "type": "upstream_disconnected", + "message": disconnectMsg, + }, }) return nil, &UpstreamFailoverError{ StatusCode: http.StatusBadGateway, @@ -7055,7 +7077,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http RetryableOnSameAccount: true, } } - sendErrorEvent("stream_read_error") + sendErrorEvent("stream_read_error", disconnectMsg) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line @@ -7114,7 +7136,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, 
originalModel) } - sendErrorEvent("stream_timeout") + sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval)) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index 389831fa..f3a52553 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/http/httptest" - "strings" "testing" "time" @@ -246,7 +245,15 @@ func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err) require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试") - require.Contains(t, string(failoverErr.ResponseBody), "upstream stream disconnected") + + // ResponseBody 必须是 Anthropic 标准 error 格式: + // 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖) + // 2) error.type 标记为 upstream_disconnected + extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody) + require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息") + require.Contains(t, extractedMsg, "upstream stream disconnected") + require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`) + require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`) // 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定 require.NotContains(t, rec.Body.String(), "stream_read_error") @@ -282,9 +289,11 @@ func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *tes var failoverErr *UpstreamFailoverError require.False(t, errors.As(err, &failoverErr), 
"已经向客户端写过字节时不能再 failover") - // 客户端必须收到 stream_read_error 事件 + // 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error, + // error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误) body := rec.Body.String() - require.True(t, - strings.Contains(body, "stream_read_error"), - "已开始流后必须发送 stream_read_error 事件给客户端,实际响应: %q", body) + require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧") + require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)") + require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error") + require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案") } From 7452fad8205b2a5ece283ba4ae00741303e1ae00 Mon Sep 17 00:00:00 2001 From: Oganneson Date: Tue, 28 Apr 2026 20:36:50 +0800 Subject: [PATCH 29/46] fix(openai): drop reasoning items from /v1/responses input on OAuth path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #1957 The OAuth path forwards client requests to chatgpt.com/backend-api/codex/responses, where applyCodexOAuthTransform forces store=false (chatgpt.com's codex backend rejects store=true). Reasoning items emitted under store=false are NEVER persisted upstream, so any rs_* reference that a client carries forward in a subsequent input[] array triggers a guaranteed upstream 404: Item with id 'rs_...' not found. Items are not persisted when `store` is set to false. Try again with `store` set to true, or remove this item from your input. sub2api wraps this as 502 "Upstream request failed" and the conversation breaks on every multi-turn /v1/responses request that uses reasoning + tools (reproducible with gpt-5.5; gpt-5.4 happens to dodge it because the upstream does not emit reasoning items for that model). 
Affected clients include any that follow the OpenAI Responses API spec and replay prior assistant items verbatim — in practice this hit OpenClaw and similar agent harnesses on every turn ≥2 with tool use. The fix: in filterCodexInput, drop input items with type == "reasoning" entirely. The model never reads reasoning summary text from input (only encrypted_content can carry reasoning context across turns, and chatgpt.com under store=false does not emit it), so this is a no-op for the model itself and a clean removal of unreachable upstream lookups. Scope is intentionally narrow: * Only OAuth account requests (account.Type == AccountTypeOAuth) reach applyCodexOAuthTransform / filterCodexInput. * API-key accounts going to api.openai.com/v1/responses are unaffected (store=true works there, rs_* persists, multi-turn already works). * Anthropic / Gemini platform groups go through different transforms and are unaffected. * /v1/chat/completions is unaffected (no reasoning items). * item_reference items (different type) are unaffected — only type == "reasoning" is dropped. Verification: * Existing tests pass: go test ./internal/service/ -run Codex|Tool|OAuth * New regression test asserts reasoning items are dropped under both preserveReferences=true and preserveReferences=false. * End-to-end repro on gpt-5.5 multi-turn + tools: pre-patch 502, post-patch 200. Repro on gpt-5.4 unchanged. Three-turn deep loop on gpt-5.5 passes. 
--- .../service/openai_codex_transform.go | 8 +++ .../service/openai_codex_transform_test.go | 53 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index e765d7e9..59fb7a33 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -853,6 +853,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } typ, _ := m["type"].(string) + // chatgpt.com codex backend (OAuth path) does not persist reasoning + // items because applyCodexOAuthTransform forces store=false. Any rs_* + // reference replayed in input is guaranteed to 404 upstream + // ("Item with id 'rs_...' not found"). Drop reasoning items entirely. + if typ == "reasoning" { + continue + } + // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id; // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 fixCallIDPrefix := func(id string) string { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 75f5c55c..b392cf96 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,6 +1,8 @@ package service import ( + "fmt" + "strings" "testing" "github.com/stretchr/testify/require" @@ -1094,3 +1096,54 @@ func TestIsInstructionsEmpty(t *testing.T) { }) } } + +func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) { + // Reasoning items in input[] reference rs_* IDs that were emitted by + // chatgpt.com under store=false (forced by applyCodexOAuthTransform). + // They are never persisted upstream, so forwarding them produces a + // guaranteed 404 ("Item with id 'rs_...' not found"). Drop them + // regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957. 
+ + build := func() []any { + return []any{ + map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"}, + map[string]any{ + "type": "reasoning", + "id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344", + "summary": []any{}, + }, + map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"}, + } + } + + for _, preserve := range []bool{true, false} { + preserve := preserve + t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) { + filtered := filterCodexInput(build(), preserve) + + for _, raw := range filtered { + item, ok := raw.(map[string]any) + require.True(t, ok) + require.NotEqual(t, "reasoning", item["type"], + "reasoning items must be dropped from input on the OAuth path") + if id, ok := item["id"].(string); ok { + require.False(t, strings.HasPrefix(id, "rs_"), + "no item carrying an rs_* id should survive the filter") + } + } + + // Sanity check: the non-reasoning items should still be present. 
+ gotTypes := make(map[string]int) + for _, raw := range filtered { + item, ok := raw.(map[string]any) + require.True(t, ok) + gotTypes[item["type"].(string)]++ + } + require.Equal(t, 1, gotTypes["message"]) + require.Equal(t, 1, gotTypes["function_call"]) + require.Equal(t, 1, gotTypes["function_call_output"]) + require.Equal(t, 0, gotTypes["reasoning"]) + }) + } +} From da4b078df22c295b3dd665aea1714bcf14184bb9 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 14:41:35 +0800 Subject: [PATCH 30/46] chore: update sponsors --- README.md | 7 +++++++ README_CN.md | 7 +++++++ README_JA.md | 7 +++++++ assets/partners/logos/pateway.png | Bin 0 -> 8228 bytes 4 files changed, 21 insertions(+) create mode 100644 assets/partners/logos/pateway.png diff --git a/README.md b/README.md index 3e609d65..718730c6 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control. + +pateway +Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line. +Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information. 
+Register now via this link to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150. + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index add32a17..24600e0e 100644 --- a/README_CN.md +++ b/README_CN.md @@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。 + +pateway +感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。 +同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。 +现在通过 此链接 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。 + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index ccd595b9..1e89610c 100644 --- a/README_JA.md +++ b/README_JA.md @@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。 + +pateway +PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。 +エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。 +こちらのリンクから登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。 + + ## エコシステム diff --git a/assets/partners/logos/pateway.png b/assets/partners/logos/pateway.png new file mode 100644 index 0000000000000000000000000000000000000000..7ca3489a9248cd12668f05c73d1fe929ef1795c8 GIT binary patch literal 8228 zcmeHsiB}S9{5O^Nre@kTHE)`hsg;?gC6#ODl4zK_iCbnasJM#?qGgMjTj7?u<}QZh z0_Mu?S{j*}3!vhXp^*~d2CnFDy61iWhxa|_$AKBn%$#{jy=Tih@9( zgAfZ-I}m7ZHSqtd@SnhTKs_rG*olN%xI}0teOcJ8-&)mQXutgZ>c51K0-I;@dnkWB>(@Pb3@LP)ow_It%D5Lg`d9E* ziB12!jPV%9=bO8<@xaN>>#9#jE!aX@26ROzxy^A*n{3=2uTAwj}yj!fI#=0 zEo}Axk0&kZ033^X^uMS6|Co*J1Yj|@B_F>MO}sieF;RV7IE%K)BhSUwJuS7FCMP&U 
zu+(tVZ8_`O_iN}S%Ek95giM9yDcQfmcrkUdt_`Z#{0|NtU11mHm zIS{B_Lb>`FLva|My1^k=z;PZP9u6-nJ5(aOWu<$T6cKRyBG$sjMjm|!1S$cuCg0L! zx?hBY^RwF~`!}Uy<$DUrN)m}zM=C3U0amsuj ziV*xeXN zaJvK~RZP0Jy$&2tMYeI8+=e^f-9B#pIJdXrX=!yPf5cYm(dP2pf!?ycpnI7+JyH8K zsKdkdA(R?BG8RhFF7K8lVC^g7Rd}655fErS$$48x)*GdHjX-qT6>t;qXj{L=Jki7w zhBvt<=Z7)g>$1G=SRI0Z&*-SiJ}-vm!~$g}H%W;2xT@;ZDFzi* z%R^z79mzD6vPYa3Ht>_*)6 zR}1*Q6*m<~Y31;I+j^A7yzPt^$W6ohMj6XhXL`n9;s`-n1b2wy7=aCUxr~2#=!n_# zmjkQ--9nG3!m*jqY9Q>W)7X-|xwUC0>}{oLD0he2wL%jicK>9{rkQtWDC_~b!@ac| zYDN0|_lu~7nx$$1vxL%-`U9jqu6Dhe*!j^bbp4T#r z6tUC>BK!{9;$~NElwed$3Uid5tZk^gp&@PdN+#t4g**tW&Ci#U7BVjL-0hlKoaBap za?4Vjm!qSj0}P9scUpdUKtA5l&dbvHiEJB^jUZvVWlwP$eO8WXoKxP?@6{gk+iz@B z=TsZs%t&721vf`@hPb(->ll`Y9;z#^4bFwT)H$7Hlj*SLsj189_+~}3=PLp<{N34| zqizC8S9KXD!s3d)5yOOKNtBfQ2^5Rc({|0x&EM*~5&d_lFtP@(O)-_bIr&N;v@5Th zd&Maj^2>2}-eFZ;AQ1 z^@HPY;`5=CIhO z>>&G?kUXztX50R&j5|%PRffFZx*m8=?e^5&`n~hIKcIBO+#A020uLZO$h$jRywL(Z z^R1zaj%Kz%q4mtjUF{)8>Xd8H0?FmPZ;0ZWEzS+>`v65FeSMpz-Mv@1-5u}0z;a!3 zT%x&`99g`3vpf^Uaso;7Ol#bSUGTr>r8sMI;c&qiLc(*5gPaR*QSUncb5)O~axA!% zsL{g=3#7;A_yEVHh$h-2CxsKoS!1XU8BA`UuRYoo7RcrVHM_D05L*g`J}w)pt9LQ& zO|(*`PVqqDWbUUS&SLFx;n$3oUHz?<*}lKazp3{!=A3q9<%(fmBiNrh#X)CQMrX>Oky*jvI5pvJb|uVwHhr&@zlD zhr`f6Z(x`XltR1B1i=ChC{uRo$S@juT_t)wTtW{9M9{ z9e5^|JnCG9GSyKp7pDX+0`DIvhseL7s@B`kX!55r2Z5d|YirYElLec+z(5ARi;qs& z;&Yn4gky_bJr4stVroJnOhF{TWM#t*S~YBKs0X#zkrg|N>AtJIbjQwan0tR3(!IG6 zX+ne50{N;1S?KC|@Ra$s86zun|0#2QcG^O+sFh+pz59zx-X-n07JsVz+q*x9mzl1U zld5HCrtQ&kTC^&eT+W|=g&g0x?KeN%do6K1k#e{67nkb_6*`jgsI{X5(XQ@qE48q* zGc~_l3N%y?WeVLiTpwd;bGDX>Lx#9|KJI#z>9KoU*t%}76Lg~9HrlsU*=yp|d_BVt zt+Ad}-~?2n7Pq2pIf`5cDZU+ou5dRz88nbzdRp8g6U!dyp>$nN6(XwS_n9C3^FcQO zg3B96bXwmSZd!RI++$&N|DS#xanm)J<-^$5iQQ(`?JQ?ZI@T=Ov`m0};|@&w>7y zvYvPoYieL@QH;>k z*{IhS;Xw^3m%>@gAwxy9&DpASXIrV#+151rHmUf;-(6#WP&UFBTGYnKBc>hl{eA)N zfdE993BA?NTG38thPO@Txa4F5Uj~Qdrja}6e{q4>bp4P}M^;aOHhNvMjO()9HDeR0 zA7r%lMSW+XAdR`6;g8Zdb{0AtE|XB0J^JmkVeC2QwY?Ov?FM}hnd^3)7P@-qi${~m 
zyN2_`5Z8WOf7A+pdxC=`tW7twmXU3EB8<^#?~|&`iX0x>c6T)Up%lbD!AehXuJ~mA zN8`Gl1nb~&l(={5x-mSknL+dj{3Th+6>`LyS|}j9X5PrBna_n@mJ)=DcD-@KJAHfT zxOx1Sca1%yC4Vy5Sv;w$KYg-h_8oubErVm)s}IH(t3#hOGw@WF<-*3VVoh7AY~w?> zxcj_s{B3577ldkzH45r{#$NmSO=6*#q;h^u-%_r8Be=+3xw58^X!~il z*=Tg4S`B{GDdavU{+a&6UOWE1pSs4`{36=9havgpMY_()$#t_{6HP~Kv4^bq&8taQ z<^D3Wt<_0VVTBYvY!_v0ZK_ZWi7mA8%_Tp%|9k(FVkS#Ct)v^Ohgs7f99IiVEQ3=q zMh{c9N|RcZmwo^vNP$C8os3^)-XiVgYZN{a?P_v1=m5C+U@q36xQB z{;gAa2#@HQeULo0+{`i{^?0VHj6gHHL_FnR$sa)D_i+61@_O0xqhfyo5XK3!xeCFo zFV6f-dtLiMbg|05@sY%HDPBlQ3WD?w>@^+)PCKm1nEg6pGJ*WwfN;0Z7ZU93bP_pZ zb&)D(YN=p|9o!_G@Y?<`@z>KEn9R6K5UMC8!+X9%N}y!?DBE5_I7f>!5Toy8E0w&H zmf|5<+`H#k8Ikm>4{_3~0s0L9opGyV2#Ohms_o~Uu7C$Pj!cqh<+J;Z2z!lebH{%h zB1v2;P&9f=z?V;}E^VwYF)0Zlj2y6Jx0-De1k}%aHm2&K;mM6QYiKQ+)oJ2<5v7Ze z@3su;BC=P->Fz~GCbyEV+M#yfV|{1AipViBP}iV` zcA9p}X8X8kGH?v5hGKQ+UDt1tp81eY=BtR#qe?dJ&0D)|D5$4janV*=$aJTR!Qz05 zCYS=vRv(=}_orzRp_MQBtByS-tyPwf!ySSc0ksR#qzzje;?@L)ChJ|e+O|D&ROo(` zUIk(8ELG?l!%NjmR`=$3*ozgrz1%CQT5~b_b&lxsy$t~M;W@?6-NN*)kpf2QnLp$p ze=B?6xNc^WCG?m0Ms&HoEKUIn$h00}Uos|iL%q!Xi~l?REd?~gtI=~>X1aO4UXOOc z$F*b{sZ()YOpi+;sa#KbM}KcR`OVv-Y}yX0LU;g${dU=DWuraQ zbG=)Z=%zsp0tX5jvvsHJ%ZT6URj2+#-e2z$^0mxAgxw&_KU&{dVU3aN3oRmlzF?%* z$Dsv`)G($O6~EM?x~L{Xp#szWmU)~Hcu8u^+b|ButXpuraTFefSdo(1U^`Bb;~@U~vn7L@zWp(EzTMxqjJ3 zrIPiVJ@O<&Q#Dec&tn+hxJjaX#t>ai#fGcBVA{y7dH8$C-!yLC=~rJe^A)V)DxNMhpn-u=VT6KSJuSjV`wNuN_i1? 
zC;TZwVhSra> z%8mZ+)Du}F4})}>4h=?ec^^VtuQ&AjZAFR)H#6u7WkkcRH5N|Tg8gdzbKDdMd8+wL z&4~be>c4ZbRat5PUD??dVBitkILz1OiL4RBD5#y$OPB4*KfP`AW{20SRWEGSqH66j zDW*Oy-7l_7SkP8I@J`k}xE7^($CWT`GvfHfm90$Up#J2D3-cAdMdLv$6Ko75WP-^o zTQf-?x=4&?HW}Th>I)EdZ!MoO%snqP(*fXd3n?=X6<=z2aFlv(@Az~DYA!*95M7sd zcegG6MT^q8816}iw=;3tcYD4?aHnw^=dwk*P%)5}+ti^5@ujr+qJQ!|=NX+-(O2rU z#Y_lAk68YAO|@5@zJMaMRI}D(R|Q49);5Kg^?pN#WS)a-R?VUr%1Dh-Rd2^cRPc5U zG|}(EMrCV~1?{%c+zM;JsW<}SE#h38?Wno46BN%73DzLM8Irdw1vvl3nXV-3Ys)5P zL%xxKrm?Sc=S~*Cl;q<523G!0k8>IK{%l;9_3>QK4O5VslMGzy*g_Tuey*p-X@u!!xZs{WLwN5M`x(1qX}tX>r5Nh1US3 zB1|3-Oz;3p2V}2ObaGR;MXo^y_L520%$HvFmO;4TA^!NktUWW}UR5<-;J;=^toB>} zS>kE8ZV0i|_9Dzbl$>uFrYhIy9^S_S=n-oRrNOYi*5HC%9v!L<>GI7RJ^l2x>cW_M zNPylIpM-0-CZAa!LNuIXP&@$@j?Eii2tg4109A;@dNkWP++D9L4GSEUWP+!kL+pNI zIXCz92Gvz9mM=XzyC7$7L#IMhC+?<=N@G#}gax}m56gl0j?*_3umXo19*s*9!AEfh za$a-e(-GK^2)Cfj2Q#5DNHka^jM%5TmYSLuXaas0@)e9GPM$Zz&iW`Gicb14ZE9r_9(Q?FvC{1tX7)#Q zCN-J+%J<;YF41ANfX)KUs7nsjP@F{A80h%-wamjeUKhlv{OlS;+HE=N*}yY+;>$i? zYn)5d704jJaY4*V@6rwx`KZ=XF>_40>sU~dO^IuN<7}WPTp~3Pn4#Xe1_0%k(iNdW zU^!M#4QH`qKb7Qh$EU-?wd3drEZ0^_XMy*8Ro!C2EN70OQ0_g=ye0AMuB@>N(1rm} z|5#eq`x8JP(tDTyeK4vsf^g{!e}CdM)ukxef+m;9U%h;nXmQ}8MjqxayF*Wn>i^Ib?{kwi90nYh`}_F7pDT0 z7SLpKb91Lrw~4d`Ak@`__p%#|5zqV0(@QPPY*%}*Ij}$?Aj&eCOvPPa!{S0+%wWz&Wv4;>i+5iDu!_}91x5}V!{;!Xrw)%OLyn%RCNVy*xWy@0_+_;szT}g+_Qt7WHNcw>_AYifx-8!HEu_U{brZ9VeuL=uBf;eJ=Wi;DiY&K-#9Up z{07r8MUXN(G2Yxs(~p<1uyHj{L`|uz-EeQ+h@0AA;|QK)%D8fjGCs7TF(c@=NLxy#GbhiWWW_g z=kF^U@3i%7rJ_2&{G*e@SnT*b@~Yb2dj7@b^xf0yA+D{Jl#MU5dKSHxd0~jGIRfTN z-i@Q0{gh5@O18k3*SS-XP=a{}xMGv|5 zkw(z9h5$Dg1&-)#=2Jw2TM4*D23gh0~C3cM@Slv|%w_cdBi&4oJ!CirS>D z^7%A|tbMIa;G}LD;49QpKLw0g%K_vZNLzsPrsd{gdil|_02~tjRXL|R(jAN7(rU`}@ zr^Uzh%9ee*9NeLtRlDSNSyMQXPW3cQ*xBxRf3NhkN1#O|d;u|^jRUh^p~|SW1OMsH zX+4?&cu)2_C*H*jXmvHIgDk=D-zR>cYmxU>9@kO8YS{!JAiL!!b5Y$p^;5&=O)i4V zt#a6U-5Q@|N6^6xU-S{*gPPN~>Le}3>YRZ{KpE3rJWXG~KCpVa6f-dW zc|gxIuN`tjwJ3moxXSk|5Xd_XJ6nd9!Fw?zglzaFvxf5TbMFb$x_kc-_}nUwcxwgwK!CxE?iSsum?*H@t9u3kL_3QK#Z 
zzLUOjzAQnv5C>qP1;58$>$(hv?22yh3Xod?jbn?tiXic7Cspn4Y9RX6@I*kaKz_*5 z_{F^kBoz8kPgt2&0aRV)0KfqXg5*Gpdhq#bdjL^;cE-Wy8bQ;~O88d;@gxRjp1SbwYf-R-NCzd&!V0orgMD}vq!(q+X}($FcbD*64ChSa-du*z%!ci0ZKkKfNFWCmBIN$xoa1IFHF$waNw zws^hlujDg;8};bnMdLK)Tfh^GZ4)Ww-j*Q)TuM#ML|W*6YfFnQEUCoS$ZiEiZn&r|7ruT5<{&sFUC4vAZ@xpP2@#U5j9}fel790q5Hq5{-2+XPrG~W X>0Tf#^eJP21p`6MtWE2#-A?=;CF9gr literal 0 HcmV?d00001 From 4b6954f9f05876b07c03f62e8492e061f9cb7bfb Mon Sep 17 00:00:00 2001 From: erio Date: Wed, 29 Apr 2026 15:01:02 +0800 Subject: [PATCH 31/46] feat(ops): allow retention days = 0 to wipe table on each scheduled cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Background / 背景 The ops cleanup task currently rejects retention days < 1 in both validate and normalize, so operators who want minimal-history setups (e.g. high churn deployments that prefer near-realtime cleanup) cannot express that intent through the UI. The only options are 1+ days, which keeps at least 24h of history regardless of cron frequency. ops 清理任务目前在 validate 和 normalize 两处都拒绝小于 1 的保留天数, 让希望尽量不留历史的运维场景(高吞吐部署 + 想用近实时清理)无法通过 UI 表达。最低只能配 1,等于不管 cron 多频繁,至少都会保留 24 小时的历史。 Purpose / 目的 Let admins set retention days to 0, meaning "every scheduled cleanup run wipes the corresponding table(s) entirely". Combined with a more frequent cron (e.g. `0 * * * *`) this yields effectively rolling cleanup. 允许管理员把保留天数设为 0,语义为"每次定时清理时把对应表全部清空"。 搭配更频繁的 cron(比如每小时整点)即可获得近似滚动清理的效果。 Changes / 改动内容 Backend - service/ops_settings.go: validate accepts [0, 365]; normalize only refills default 30 when value is < 0 (negative is treated as legacy bad data, 0 is honoured) - service/ops_cleanup_service.go: introduce `opsCleanupPlan(now, days)` returning `(cutoff, truncate, ok)`. days==0 returns truncate=true and short-circuits to a new `truncateOpsTable` helper that uses `TRUNCATE TABLE` (O(1), no WAL, no VACUUM pressure). days>0 keeps the existing batched DELETE path unchanged. 
Empty tables skip TRUNCATE to avoid the ACCESS EXCLUSIVE lock entirely - Extract `isMissingRelationError` helper to dedupe the "table not yet created" tolerance shared by both delete and truncate paths - Add unit tests for `opsCleanupPlan` (three branches) and `isMissingRelationError` 后端 - service/ops_settings.go: validate 接受 [0, 365];normalize 仅在 < 0 时回填默认 30(负数视为脏数据,0 被尊重) - service/ops_cleanup_service.go: 抽 `opsCleanupPlan(now, days)` 返回 `(cutoff, truncate, ok)`。days==0 → truncate=true,走新增 `truncateOpsTable`(TRUNCATE TABLE,O(1),无 WAL、无 VACUUM 压力); days>0 仍走原批量 DELETE 路径,行为完全不变。空表跳过 TRUNCATE, 避免无意义的 ACCESS EXCLUSIVE 锁 - 抽 `isMissingRelationError` helper 复用 delete / truncate 两处的 "表不存在"宽容判断 - 补 `opsCleanupPlan` 三分支 + `isMissingRelationError` 单元测试 Frontend - OpsSettingsDialog.vue: validation accepts [0, 365]; input min=0 - i18n (zh/en): hint mentions "0 = wipe all on every cleanup", validation message updated to 0-365 range 前端 - OpsSettingsDialog.vue: 校验放宽到 [0, 365],input min 改 0 - i18n(zh/en):hint 补"0 = 每次清理时清空所有",错误提示改 0-365 Trade-offs / 取舍 - TRUNCATE requires ACCESS EXCLUSIVE lock briefly, but ops tables only have the cleanup task as a writer, so the lock is invisible to other workloads - Empty-table guard avoids the lock when there is nothing to clean - Negative values are still treated as legacy bad data and replaced with default 30 to preserve compatibility --- .../internal/service/ops_cleanup_service.go | 97 ++++++++++++++++--- .../service/ops_cleanup_service_test.go | 64 ++++++++++++ backend/internal/service/ops_settings.go | 21 ++-- frontend/src/i18n/locales/en.ts | 4 +- frontend/src/i18n/locales/zh.ts | 4 +- .../ops/components/OpsSettingsDialog.vue | 12 +-- 6 files changed, 167 insertions(+), 35 deletions(-) create mode 100644 backend/internal/service/ops_cleanup_service_test.go diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 08a10a02..44ec1ad1 100644 --- 
a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string { ) } +// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。 +// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据 +// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true +// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天 +// +// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE": +// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成 +// - 无 WAL 写入、无后续 VACUUM 压力 +// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略 +func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) { + if days < 0 { + return time.Time{}, false, false + } + if days == 0 { + return time.Time{}, true, true + } + return now.AddDate(0, 0, -days), false, true +} + func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) { out := opsCleanupDeletedCounts{} if s == nil || s.db == nil || s.cfg == nil { @@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet now := time.Now().UTC() - // Error-like tables: error logs / retry attempts / alert events. - if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false) + // runOne 把"truncate? cutoff? batched delete?"封装到一处, + // 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。 + runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) { + if truncate { + return truncateOpsTable(ctx, s.db, table) + } + return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate) + } + + // Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits. 
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false) if err != nil { return out, err } out.errorLogs = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false) if err != nil { return out, err } out.retryAttempts = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false) if err != nil { return out, err } out.alertEvents = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false) if err != nil { return out, err } out.systemLogs = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false) if err != nil { return out, err } @@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet } // Minute-level metrics snapshots. - if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false) + if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false) if err != nil { return out, err } @@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet } // Pre-aggregation tables (hourly/daily). 
- if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false) + if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false) if err != nil { return out, err } out.hourlyPreagg = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true) + n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true) if err != nil { return out, err } @@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch) res, err := db.ExecContext(ctx, q, cutoff, batchSize) if err != nil { // If ops tables aren't present yet (partial deployments), treat as no-op. - if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") { + if isMissingRelationError(err) { return total, nil } return total, err @@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch) return total, nil } +// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。 +// +// 与 deleteOldRowsByID 的差异: +// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义 +// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力 +// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略 +// +// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。 +func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) { + if db == nil { + return 0, nil + } + var count int64 + if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil { + if isMissingRelationError(err) { + return 0, nil + } + return 0, fmt.Errorf("count %s: %w", table, err) + } + if count == 0 { + return 0, nil + } + if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil { + if isMissingRelationError(err) { + return 0, nil 
+ } + return 0, fmt.Errorf("truncate %s: %w", table, err) + } + return count, nil +} + +// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。 +func isMissingRelationError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + return strings.Contains(s, "does not exist") && strings.Contains(s, "relation") +} + func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) { if s == nil { return nil, false diff --git a/backend/internal/service/ops_cleanup_service_test.go b/backend/internal/service/ops_cleanup_service_test.go new file mode 100644 index 00000000..86657d27 --- /dev/null +++ b/backend/internal/service/ops_cleanup_service_test.go @@ -0,0 +1,64 @@ +package service + +import ( + "testing" + "time" +) + +func TestOpsCleanupPlan(t *testing.T) { + now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + + cases := []struct { + name string + days int + wantOK bool + wantTruncate bool + wantCutoff time.Time + }{ + {name: "negative skips", days: -1, wantOK: false}, + {name: "zero truncates", days: 0, wantOK: true, wantTruncate: true}, + {name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cutoff, truncate, ok := opsCleanupPlan(now, tc.days) + if ok != tc.wantOK { + t.Fatalf("ok = %v, want %v", ok, tc.wantOK) + } + if !ok { + return + } + if truncate != tc.wantTruncate { + t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate) + } + if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) { + t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff) + } + }) + } +} + +func TestIsMissingRelationError(t *testing.T) { + cases := []struct { + name string + err error + want bool + }{ + {name: "nil is not missing", err: nil, want: false}, + {name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true}, + {name: "match 
case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true}, + {name: "non-matching error", err: fakeErr("connection refused"), want: false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := isMissingRelationError(tc.err); got != tc.want { + t.Fatalf("got %v, want %v", got, tc.want) + } + }) + } +} + +type fakeErr string + +func (e fakeErr) Error() string { return string(e) } diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index 5871166c..ecc3a94b 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) { if cfg.DataRetention.CleanupSchedule == "" { cfg.DataRetention.CleanupSchedule = "0 2 * * *" } - if cfg.DataRetention.ErrorLogRetentionDays <= 0 { + // 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留; + // 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。 + if cfg.DataRetention.ErrorLogRetentionDays < 0 { cfg.DataRetention.ErrorLogRetentionDays = 30 } - if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 { + if cfg.DataRetention.MinuteMetricsRetentionDays < 0 { cfg.DataRetention.MinuteMetricsRetentionDays = 30 } - if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 { + if cfg.DataRetention.HourlyMetricsRetentionDays < 0 { cfg.DataRetention.HourlyMetricsRetentionDays = 30 } // Normalize auto refresh interval (default 30 seconds) @@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error { if cfg == nil { return errors.New("invalid config") } - if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 { - return errors.New("error_log_retention_days must be between 1 and 365") + // 保留天数:0 表示每次清理全部,1-365 表示按天数保留。 + if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 { + return errors.New("error_log_retention_days must be between 0 and 365") } - if 
cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 { - return errors.New("minute_metrics_retention_days must be between 1 and 365") + if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 { + return errors.New("minute_metrics_retention_days must be between 0 and 365") } - if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 { - return errors.New("hourly_metrics_retention_days must be between 1 and 365") + if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 { + return errors.New("hourly_metrics_retention_days must be between 0 and 365") } if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 { return errors.New("auto_refresh_interval_seconds must be between 15 and 300") diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c66ca55b..270cd660 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -4648,7 +4648,7 @@ export default { errorLogRetentionDays: 'Error Log Retention Days', minuteMetricsRetentionDays: 'Minute Metrics Retention Days', hourlyMetricsRetentionDays: 'Hourly Metrics Retention Days', - retentionDaysHint: 'Recommended 7-90 days, longer periods will consume more storage', + retentionDaysHint: 'Recommended 7-90 days; longer periods consume more storage. 
Set to 0 to wipe all history on every scheduled cleanup', aggregation: 'Pre-aggregation Tasks', enableAggregation: 'Enable Pre-aggregation', aggregationHint: 'Pre-aggregation improves query performance for long time windows', @@ -4678,7 +4678,7 @@ export default { autoRefreshCountdown: 'Auto refresh: {seconds}s', validation: { title: 'Please fix the following issues', - retentionDaysRange: 'Retention days must be between 1-365 days', + retentionDaysRange: 'Retention days must be between 0 and 365 (0 = wipe all on every cleanup)', slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100', ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0', requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 77d1c93c..fdfc9e41 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -4810,7 +4810,7 @@ export default { errorLogRetentionDays: '错误日志保留天数', minuteMetricsRetentionDays: '分钟指标保留天数', hourlyMetricsRetentionDays: '小时指标保留天数', - retentionDaysHint: '建议保留7-90天,过长会占用存储空间', + retentionDaysHint: '建议保留 7-90 天,过长会占用存储空间;填 0 表示每次定时清理时清空所有历史', aggregation: '预聚合任务', enableAggregation: '启用预聚合任务', aggregationHint: '预聚合可提升长时间窗口查询性能', @@ -4841,7 +4841,7 @@ export default { autoRefreshCountdown: '自动刷新:{seconds}s', validation: { title: '请先修正以下问题', - retentionDaysRange: '保留天数必须在1-365天之间', + retentionDaysRange: '保留天数必须在 0-365 天之间(0 = 每次清理时清空所有)', slaMinPercentRange: 'SLA最低百分比必须在0-100之间', ttftP99MaxRange: 'TTFT P99最大值必须大于等于0', requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间', diff --git a/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue b/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue index 542f111d..5dba5b1d 100644 --- a/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue +++ b/frontend/src/views/admin/ops/components/OpsSettingsDialog.vue @@ -136,13 +136,13 @@ const validation = computed(() => { // 
验证高级设置 if (advancedSettings.value) { const { error_log_retention_days, minute_metrics_retention_days, hourly_metrics_retention_days } = advancedSettings.value.data_retention - if (error_log_retention_days < 1 || error_log_retention_days > 365) { + if (error_log_retention_days < 0 || error_log_retention_days > 365) { errors.push(t('admin.ops.settings.validation.retentionDaysRange')) } - if (minute_metrics_retention_days < 1 || minute_metrics_retention_days > 365) { + if (minute_metrics_retention_days < 0 || minute_metrics_retention_days > 365) { errors.push(t('admin.ops.settings.validation.retentionDaysRange')) } - if (hourly_metrics_retention_days < 1 || hourly_metrics_retention_days > 365) { + if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) { errors.push(t('admin.ops.settings.validation.retentionDaysRange')) } } @@ -431,7 +431,7 @@ async function saveAllSettings() { @@ -441,7 +441,7 @@ async function saveAllSettings() { @@ -451,7 +451,7 @@ async function saveAllSettings() { From d78478e8668f0547f9639c812f2bb2641f80166f Mon Sep 17 00:00:00 2001 From: alfadb Date: Wed, 29 Apr 2026 15:44:54 +0800 Subject: [PATCH 32/46] fix(gateway): sanitize stream errors to avoid leaking infrastructure topology (*net.OpError).Error() concatenates Source/Addr fields, so the previous disconnectMsg surfaced internal source IP/port and upstream server address to clients via SSE error frames and UpstreamFailoverError.ResponseBody (reported by @Wei-Shaw on PR #2066). - Add sanitizeStreamError that maps known errors (io.ErrUnexpectedEOF, context.Canceled, syscall.ECONNRESET/EPIPE/ETIMEDOUT/...) to fixed descriptions and falls back to a generic placeholder, with an explicit *net.OpError branch that drops Source/Addr fields entirely. - Use sanitized message in client-facing disconnectMsg; full ev.err is still preserved in the existing operator log line for diagnosis. 
- Tests cover net.OpError redaction, the failover ResponseBody path, and every known sanitized error mapping. --- backend/internal/service/gateway_service.go | 50 +++++++++- .../service/gateway_streaming_test.go | 96 +++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4c4a9b82..aea0ba94 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -11,6 +11,7 @@ import ( "io" "log/slog" mathrand "math/rand" + "net" "net/http" "net/url" "os" @@ -20,6 +21,7 @@ import ( "strconv" "strings" "sync/atomic" + "syscall" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -6434,6 +6436,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { return false } +// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。 +// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游 +// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection +// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。 +func sanitizeStreamError(err error) string { + if err == nil { + return "" + } + switch { + case errors.Is(err, io.ErrUnexpectedEOF): + return "unexpected EOF" + case errors.Is(err, io.EOF): + return "EOF" + case errors.Is(err, context.Canceled): + return "canceled" + case errors.Is(err, context.DeadlineExceeded): + return "deadline exceeded" + case errors.Is(err, syscall.ECONNRESET): + return "connection reset by peer" + case errors.Is(err, syscall.ECONNABORTED): + return "connection aborted" + case errors.Is(err, syscall.ETIMEDOUT): + return "connection timed out" + case errors.Is(err, syscall.EPIPE): + return "broken pipe" + case errors.Is(err, syscall.ECONNREFUSED): + return "connection refused" + } + var netErr *net.OpError + if errors.As(err, &netErr) { + if netErr.Timeout() { + if netErr.Op != "" { + return netErr.Op + " timeout" + } + return "i/o timeout" + } + if netErr.Op != "" { + return 
netErr.Op + " network error" + } + } + return "upstream connection error" +} + // ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 // 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} func ExtractUpstreamErrorMessage(body []byte) string { @@ -7061,7 +7106,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY): // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。 // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。 - disconnectMsg := fmt.Sprintf("upstream stream disconnected: %s", ev.err) + // 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址, + // 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err + // 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。 + disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err) if !c.Writer.Written() { logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err) body, _ := json.Marshal(map[string]any{ diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index f3a52553..ef09a882 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -6,8 +6,10 @@ import ( "context" "errors" "io" + "net" "net/http" "net/http/httptest" + "syscall" "testing" "time" @@ -297,3 +299,97 @@ func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *tes require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error") require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案") } + +// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游 +// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过 +// failover ResponseBody 或 SSE error 帧返回给客户端。 +func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) { + src, err := 
net.ResolveTCPAddr("tcp", "10.0.0.1:54321") + require.NoError(t, err) + dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443") + require.NoError(t, err) + + raw := &net.OpError{ + Op: "read", + Net: "tcp", + Source: src, + Addr: dst, + Err: syscall.ECONNRESET, + } + + // 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过) + require.Contains(t, raw.Error(), "10.0.0.1") + require.Contains(t, raw.Error(), "52.1.2.3") + + got := sanitizeStreamError(raw) + require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP") + require.NotContains(t, got, "54321", "不得泄露源端口") + require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP") + require.NotContains(t, got, "443", "不得泄露上游端口") + require.Equal(t, "connection reset by peer", got) +} + +func TestSanitizeStreamError_KnownErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"}, + {"EOF", io.EOF, "EOF"}, + {"context canceled", context.Canceled, "canceled"}, + {"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"}, + {"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"}, + {"EPIPE", syscall.EPIPE, "broken pipe"}, + {"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"}, + {"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"}, + {"nil 返回空串", nil, ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.want, sanitizeStreamError(tc.err)) + }) + } +} + +// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志 +// 时携带内部地址信息。 +func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321") + dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443") + netErr := &net.OpError{ + Op: "read", + 
Net: "tcp", + Source: src, + Addr: dst, + Err: syscall.ECONNRESET, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{err: netErr}, + } + + _, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr)) + + body := string(failoverErr.ResponseBody) + require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP") + require.NotContains(t, body, "54321") + require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP") + require.NotContains(t, body, "443") + // 仍然包含可诊断的根因 + require.Contains(t, body, "connection reset by peer") + require.Contains(t, body, "upstream stream disconnected") +} From 93d91e20b9da4dd1986085ee62fb716213df0c5b Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 16:53:09 +0800 Subject: [PATCH 33/46] fix(vertex): audit fixes for Vertex Service Account feature (#1977) - Security: force token_uri to Google default, preventing SSRF via crafted service account JSON - Dedup: extract shared getVertexServiceAccountAccessToken() to eliminate ~35 lines of duplication between ClaudeTokenProvider and GeminiTokenProvider - Fix: apply model mapping + Vertex model ID normalization in forward_as_responses and forward_as_chat_completions paths - Fix: exclude service_account from AI Studio endpoint selection (Vertex cannot serve generativelanguage.googleapis.com) - Feature: add model restriction/mapping UI for service_account in EditAccountModal - Dedup: extract VERTEX_LOCATION_OPTIONS to shared constants - i18n: replace all hardcoded Chinese strings in Vertex UI with translation keys --- .../internal/service/claude_token_provider.go | 37 +-- .../gateway_forward_as_chat_completions.go | 9 +- .../service/gateway_forward_as_responses.go | 9 +- 
.../service/gemini_messages_compat_service.go | 4 + .../internal/service/gemini_token_provider.go | 37 +-- .../service/vertex_service_account.go | 48 +++- .../components/account/CreateAccountModal.vue | 71 +---- .../components/account/EditAccountModal.vue | 266 ++++++++++++++---- frontend/src/constants/account.ts | 48 ++++ frontend/src/i18n/locales/en.ts | 20 ++ frontend/src/i18n/locales/zh.ts | 20 ++ 11 files changed, 378 insertions(+), 191 deletions(-) diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index 9292979f..d70379c1 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -162,40 +162,5 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou } func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { - key, err := parseVertexServiceAccountKey(account) - if err != nil { - return "", err - } - cacheKey := vertexServiceAccountCacheKey(account, key) - - if p.tokenCache != nil { - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - } - - locked := false - if p.tokenCache != nil { - var lockErr error - locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if lockErr == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - } else if lockErr != nil { - slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) - } else { - time.Sleep(claudeLockWaitTime) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - } - } - - accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) - if err != nil { - return "", err - } - if p.tokenCache != nil { - _ = p.tokenCache.SetAccessToken(ctx, 
cacheKey, accessToken, ttl) - } - return accessToken, nil + return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account) } diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index c531667e..7ac77f77 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions( // 4. Model mapping mappedModel := originalModel - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } - if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel)) + if normalized != originalModel { + mappedModel = normalized + } + } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(originalModel) if normalized != originalModel { mappedModel = normalized diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 647193d6..8f8a1e94 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses( // 4. 
Model mapping mappedModel := originalModel reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } - if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel)) + if normalized != originalModel { + mappedModel = normalized + } + } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(originalModel) if normalized != originalModel { mappedModel = normalized diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 20293ac8..ea0c0d7d 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont } // Code Assist OAuth tokens often lack AI Studio scopes for models listing. return 3 + case AccountTypeServiceAccount: + // Vertex service accounts use aiplatform.googleapis.com, not the AI Studio + // endpoint (generativelanguage.googleapis.com), so they cannot serve these requests. 
+ return 999 default: return 10 } diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index c22f2131..172b9411 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -172,42 +172,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { - key, err := parseVertexServiceAccountKey(account) - if err != nil { - return "", err - } - cacheKey := vertexServiceAccountCacheKey(account, key) - - if p.tokenCache != nil { - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - } - - locked := false - if p.tokenCache != nil { - var lockErr error - locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) - if lockErr == nil && locked { - defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() - } else if lockErr != nil { - slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) - } else { - time.Sleep(200 * time.Millisecond) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { - return token, nil - } - } - } - - accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) - if err != nil { - return "", err - } - if p.tokenCache != nil { - _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) - } - return accessToken, nil + return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account) } func GeminiTokenCacheKey(account *Account) string { diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go index d4130b93..4430cf81 100644 --- a/backend/internal/service/vertex_service_account.go +++ 
b/backend/internal/service/vertex_service_account.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "net/url" "regexp" @@ -23,6 +24,7 @@ const ( vertexDefaultTokenURL = "https://oauth2.googleapis.com/token" vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" vertexServiceAccountCacheSkew = 5 * time.Minute + vertexLockWaitTime = 200 * time.Millisecond vertexAnthropicVersion = "vertex-2023-10-16" ) @@ -123,9 +125,8 @@ func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) if strings.TrimSpace(key.ProjectID) == "" { return nil, errors.New("service account json missing project_id") } - if strings.TrimSpace(key.TokenURI) == "" { - key.TokenURI = vertexDefaultTokenURL - } + // Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri. + key.TokenURI = vertexDefaultTokenURL return &key, nil } @@ -141,6 +142,47 @@ func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey return "vertex:service_account:" + fingerprint } +// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account, +// using the shared cache and distributed lock to avoid redundant exchanges. 
+func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) { + key, err := parseVertexServiceAccountKey(account) + if err != nil { + return "", err + } + cacheKey := vertexServiceAccountCacheKey(account, key) + + if cache != nil { + if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + locked := false + if cache != nil { + var lockErr error + locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { + time.Sleep(vertexLockWaitTime) + if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + } + + accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) + if err != nil { + return "", err + } + if cache != nil { + _ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + return accessToken, nil +} + func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) { now := time.Now() claims := jwt.MapClaims{ diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e7a790ec..d38c31c5 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -276,7 +276,7 @@ v-if="accountCategory === 'service_account'" class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200" > -

使用 Google Cloud Service Account JSON 通过 Vertex AI 调用 Anthropic Claude。建议配置模型映射,将客户端 Claude 模型名映射到 Vertex 模型 ID。

+

{{ t('admin.accounts.vertexAnthropicHint') }}

@@ -479,7 +479,7 @@ v-if="accountCategory === 'service_account'" class="mt-3 rounded-lg border border-sky-200 bg-sky-50 px-3 py-2 text-xs text-sky-800 dark:border-sky-800/40 dark:bg-sky-900/20 dark:text-sky-200" > -

使用 Google Cloud Service Account JSON 访问 Vertex AI Gemini。建议将 Vertex 账号放入独立分组,避免和 AI Studio/Gemini OAuth 同模型混调。

+

{{ t('admin.accounts.vertexGeminiHint') }}

@@ -827,10 +827,10 @@
- {{ vertexClientEmail ? '已读取 Service Account JSON' : '拖入 Service Account JSON' }} + {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}

- {{ vertexClientEmail ? '密钥内容不会在表单中显示。' : '把 .json 文件拖到这里,或点击按钮选择文件。' }} + {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}

Client Email: {{ vertexClientEmail }}
-

上传或拖入 JSON 后会自动读取 project_id,密钥内容仅用于创建账号提交。

+

{{ t('admin.accounts.vertexSaJsonUploadHint') }}

@@ -861,7 +861,7 @@ type="text" class="input font-mono" readonly - placeholder="从 JSON 自动读取" + :placeholder="t('admin.accounts.vertexProjectIdPlaceholder')" />
@@ -872,7 +872,7 @@ class="input font-mono" > @@ -885,7 +885,7 @@ -

不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。

+

{{ t('admin.accounts.vertexLocationHint') }}

@@ -3132,6 +3132,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' +import { VERTEX_LOCATION_OPTIONS } from '@/constants/account' import { OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, @@ -3318,52 +3319,6 @@ const vertexProjectId = ref('') const vertexClientEmail = ref('') const vertexLocation = ref('global') const vertexServiceAccountDragActive = ref(false) -const vertexLocationOptions = [ - { - label: 'Common', - options: [ - { value: 'us-central1', label: 'us-central1 (Iowa)' }, - { value: 'global', label: 'global' }, - { value: 'us', label: 'us' }, - { value: 'eu', label: 'eu' } - ] - }, - { - label: 'United States', - options: [ - { value: 'us-east1', label: 'us-east1 (South Carolina)' }, - { value: 'us-east4', label: 'us-east4 (Northern Virginia)' }, - { value: 'us-east5', label: 'us-east5 (Columbus)' }, - { value: 'us-south1', label: 'us-south1 (Dallas)' }, - { value: 'us-west1', label: 'us-west1 (Oregon)' }, - { value: 'us-west4', label: 'us-west4 (Las Vegas)' } - ] - }, - { - label: 'Europe', - options: [ - { value: 'europe-west1', label: 'europe-west1 (Belgium)' }, - { value: 'europe-west2', label: 'europe-west2 (London)' }, - { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' }, - { value: 'europe-west4', label: 'europe-west4 (Netherlands)' }, - { value: 'europe-west6', label: 'europe-west6 (Zurich)' }, - { value: 'europe-west8', label: 'europe-west8 (Milan)' }, - { value: 'europe-west9', label: 'europe-west9 (Paris)' } - ] - }, - { - label: 'Asia Pacific', - options: [ - { value: 'asia-east1', label: 'asia-east1 (Taiwan)' }, - { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' }, - { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' }, - { value: 'asia-northeast3', label: 
'asia-northeast3 (Seoul)' }, - { value: 'asia-south1', label: 'asia-south1 (Mumbai)' }, - { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' }, - { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' } - ] - } -] as const const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -4251,7 +4206,7 @@ const applyVertexServiceAccountJson = (value: string) => { const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : '' const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : '' if (!projectId || !clientEmail || !privateKey) { - appStore.showError('Service Account JSON 缺少 project_id、client_email 或 private_key') + appStore.showError(t('admin.accounts.vertexSaJsonMissingFields')) return false } vertexProjectId.value = projectId @@ -4259,7 +4214,7 @@ const applyVertexServiceAccountJson = (value: string) => { vertexServiceAccountJson.value = JSON.stringify(parsed) return true } catch { - appStore.showError('Service Account JSON 格式无效') + appStore.showError(t('admin.accounts.vertexSaJsonInvalid')) return false } } @@ -4406,7 +4361,7 @@ const handleSubmit = async () => { return } if (!vertexLocation.value.trim()) { - appStore.showError('请填写 Vertex location') + appStore.showError(t('admin.accounts.vertexLocationRequired')) return } const credentials: Record = { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 69e2186b..56874474 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -577,9 +577,9 @@ type="text" class="input font-mono" readonly - placeholder="从 JSON 自动读取" + :placeholder="t('admin.accounts.vertexProjectIdPlaceholder')" /> -

Service Account JSON 不在编辑页显示;需要更换 JSON 时请删除账号后重新创建。

+

{{ t('admin.accounts.vertexSaJsonEditHint') }}

@@ -589,7 +589,7 @@ class="input font-mono" > @@ -602,7 +602,182 @@ -

不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。

+

{{ t('admin.accounts.vertexLocationHint') }}

+
+ + + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ + t('admin.accounts.supportsAllModels') + }} +

+
+ + +
+
+

+ + + + {{ t('admin.accounts.mapRequestModels') }} +

+
+ + +
+
+ + + + + + +
+
+ + + + +
+ +
@@ -1959,6 +2134,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' +import { VERTEX_LOCATION_OPTIONS } from '@/constants/account' import { OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, @@ -2030,52 +2206,6 @@ const editBedrockApiKeyValue = ref('') const editVertexProjectId = ref('') const editVertexClientEmail = ref('') const editVertexLocation = ref('us-central1') -const vertexLocationOptions = [ - { - label: 'Common', - options: [ - { value: 'us-central1', label: 'us-central1 (Iowa)' }, - { value: 'global', label: 'global' }, - { value: 'us', label: 'us' }, - { value: 'eu', label: 'eu' } - ] - }, - { - label: 'United States', - options: [ - { value: 'us-east1', label: 'us-east1 (South Carolina)' }, - { value: 'us-east4', label: 'us-east4 (Northern Virginia)' }, - { value: 'us-east5', label: 'us-east5 (Columbus)' }, - { value: 'us-south1', label: 'us-south1 (Dallas)' }, - { value: 'us-west1', label: 'us-west1 (Oregon)' }, - { value: 'us-west4', label: 'us-west4 (Las Vegas)' } - ] - }, - { - label: 'Europe', - options: [ - { value: 'europe-west1', label: 'europe-west1 (Belgium)' }, - { value: 'europe-west2', label: 'europe-west2 (London)' }, - { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' }, - { value: 'europe-west4', label: 'europe-west4 (Netherlands)' }, - { value: 'europe-west6', label: 'europe-west6 (Zurich)' }, - { value: 'europe-west8', label: 'europe-west8 (Milan)' }, - { value: 'europe-west9', label: 'europe-west9 (Paris)' } - ] - }, - { - label: 'Asia Pacific', - options: [ - { value: 'asia-east1', label: 'asia-east1 (Taiwan)' }, - { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' }, - { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' }, - { value: 
'asia-northeast3', label: 'asia-northeast3 (Seoul)' }, - { value: 'asia-south1', label: 'asia-south1 (Mumbai)' }, - { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' }, - { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' } - ] - } -] as const const isBedrockAPIKeyMode = computed(() => props.account?.type === 'bedrock' && (props.account?.credentials as Record)?.auth_mode === 'apikey' @@ -2564,6 +2694,26 @@ const syncFormFromAccount = (newAccount: Account | null) => { editVertexProjectId.value = (credentials.project_id as string) || '' editVertexClientEmail.value = (credentials.client_email as string) || '' editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1' + + // Load model mappings for service_account + const existingMappings = credentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -3160,20 +3310,20 @@ const handleSubmit = async () => { const newCredentials: Record = { ...currentCredentials } if (!editVertexProjectId.value.trim()) { - appStore.showError('Service Account JSON 缺少 project_id') + appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId')) return } if (!editVertexClientEmail.value.trim()) { - appStore.showError('Service Account JSON 缺少 client_email') + 
appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail')) return } if (!editVertexLocation.value.trim()) { - appStore.showError('请填写 Vertex location') + appStore.showError(t('admin.accounts.vertexLocationRequired')) return } if (!currentCredentials.service_account_json && !currentCredentials.service_account) { - appStore.showError('请上传 Service Account JSON') + appStore.showError(t('admin.accounts.vertexSaJsonRequired')) return } newCredentials.project_id = editVertexProjectId.value.trim() @@ -3181,6 +3331,14 @@ const handleSubmit = async () => { newCredentials.location = editVertexLocation.value.trim() newCredentials.tier_id = 'vertex' + // Add model mapping if configured + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') if (!applyTempUnschedConfig(newCredentials)) { return diff --git a/frontend/src/constants/account.ts b/frontend/src/constants/account.ts index dcfc7fae..776de4fa 100644 --- a/frontend/src/constants/account.ts +++ b/frontend/src/constants/account.ts @@ -13,3 +13,51 @@ export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOT export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const export const QUOTA_RESET_MODE_FIXED = 'fixed' as const export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED + +/** Vertex AI location options for Service Account accounts */ +export const VERTEX_LOCATION_OPTIONS = [ + { + label: 'Common', + options: [ + { value: 'us-central1', label: 'us-central1 (Iowa)' }, + { value: 'global', label: 'global' }, + { value: 'us', label: 'us' }, + { value: 'eu', label: 'eu' } + ] + }, + { + label: 'United States', + options: [ + { value: 'us-east1', label: 'us-east1 (South Carolina)' }, + { value: 
'us-east4', label: 'us-east4 (Northern Virginia)' }, + { value: 'us-east5', label: 'us-east5 (Columbus)' }, + { value: 'us-south1', label: 'us-south1 (Dallas)' }, + { value: 'us-west1', label: 'us-west1 (Oregon)' }, + { value: 'us-west4', label: 'us-west4 (Las Vegas)' } + ] + }, + { + label: 'Europe', + options: [ + { value: 'europe-west1', label: 'europe-west1 (Belgium)' }, + { value: 'europe-west2', label: 'europe-west2 (London)' }, + { value: 'europe-west3', label: 'europe-west3 (Frankfurt)' }, + { value: 'europe-west4', label: 'europe-west4 (Netherlands)' }, + { value: 'europe-west6', label: 'europe-west6 (Zurich)' }, + { value: 'europe-west8', label: 'europe-west8 (Milan)' }, + { value: 'europe-west9', label: 'europe-west9 (Paris)' } + ] + }, + { + label: 'Asia Pacific', + options: [ + { value: 'asia-east1', label: 'asia-east1 (Taiwan)' }, + { value: 'asia-east2', label: 'asia-east2 (Hong Kong)' }, + { value: 'asia-northeast1', label: 'asia-northeast1 (Tokyo)' }, + { value: 'asia-northeast3', label: 'asia-northeast3 (Seoul)' }, + { value: 'asia-south1', label: 'asia-south1 (Mumbai)' }, + { value: 'asia-southeast1', label: 'asia-southeast1 (Singapore)' }, + { value: 'australia-southeast1', label: 'australia-southeast1 (Sydney)' } + ] + } +] as const diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 270cd660..0425955f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2815,6 +2815,26 @@ export default { claudeConsole: 'Claude Console', bedrockLabel: 'AWS Bedrock', bedrockDesc: 'SigV4 / API Key', + vertexLabel: 'Vertex', + vertexDesc: 'Service Account', + vertexAnthropicHint: 'Use a Google Cloud Service Account JSON to call Anthropic Claude via Vertex AI. It is recommended to configure model mapping to map client Claude model names to Vertex model IDs.', + vertexGeminiHint: 'Use a Google Cloud Service Account JSON to access Vertex AI Gemini. 
It is recommended to place Vertex accounts in a separate group to avoid mixing with AI Studio/Gemini OAuth on the same models.', + vertexSaJsonLabel: 'Service Account JSON', + vertexSaJsonLoaded: 'Service Account JSON loaded', + vertexSaJsonDrop: 'Drop Service Account JSON here', + vertexSaJsonKeyHidden: 'Key content is not displayed in the form.', + vertexSaJsonDropHint: 'Drag a .json file here, or click the button to select one.', + vertexSaJsonSelectBtn: 'Select JSON', + vertexSaJsonUploadHint: 'After uploading or dropping a JSON file, the project_id will be auto-extracted. Key content is only used for account creation.', + vertexSaJsonEditHint: 'Service Account JSON is not shown on the edit page; to change the JSON, delete the account and recreate it.', + vertexProjectIdPlaceholder: 'Auto-extracted from JSON', + vertexLocationHint: 'Available locations vary by Vertex model. Select the default endpoint location for this account.', + vertexLocationRequired: 'Please enter a Vertex location', + vertexSaJsonMissingFields: 'Service Account JSON is missing project_id, client_email, or private_key', + vertexSaJsonMissingProjectId: 'Service Account JSON is missing project_id', + vertexSaJsonMissingClientEmail: 'Service Account JSON is missing client_email', + vertexSaJsonInvalid: 'Service Account JSON format is invalid', + vertexSaJsonRequired: 'Please upload a Service Account JSON', oauthSetupToken: 'OAuth / Setup Token', addMethod: 'Add Method', setupTokenLongLived: 'Setup Token (Long-lived)', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index fdfc9e41..a8656a7b 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2963,6 +2963,26 @@ export default { claudeConsole: 'Claude Console', bedrockLabel: 'AWS Bedrock', bedrockDesc: 'SigV4 / API Key', + vertexLabel: 'Vertex', + vertexDesc: 'Service Account', + vertexAnthropicHint: '使用 Google Cloud Service Account JSON 通过 Vertex AI 调用 Anthropic 
Claude。建议配置模型映射,将客户端 Claude 模型名映射到 Vertex 模型 ID。', + vertexGeminiHint: '使用 Google Cloud Service Account JSON 访问 Vertex AI Gemini。建议将 Vertex 账号放入独立分组,避免和 AI Studio/Gemini OAuth 同模型混调。', + vertexSaJsonLabel: 'Service Account JSON', + vertexSaJsonLoaded: '已读取 Service Account JSON', + vertexSaJsonDrop: '拖入 Service Account JSON', + vertexSaJsonKeyHidden: '密钥内容不会在表单中显示。', + vertexSaJsonDropHint: '把 .json 文件拖到这里,或点击按钮选择文件。', + vertexSaJsonSelectBtn: '选择 JSON', + vertexSaJsonUploadHint: '上传或拖入 JSON 后会自动读取 project_id,密钥内容仅用于创建账号提交。', + vertexSaJsonEditHint: 'Service Account JSON 不在编辑页显示;需要更换 JSON 时请删除账号后重新创建。', + vertexProjectIdPlaceholder: '从 JSON 自动读取', + vertexLocationHint: '不同 Vertex 模型可用 location 可能不同,这里选择账号默认 endpoint location。', + vertexLocationRequired: '请填写 Vertex location', + vertexSaJsonMissingFields: 'Service Account JSON 缺少 project_id、client_email 或 private_key', + vertexSaJsonMissingProjectId: 'Service Account JSON 缺少 project_id', + vertexSaJsonMissingClientEmail: 'Service Account JSON 缺少 client_email', + vertexSaJsonInvalid: 'Service Account JSON 格式无效', + vertexSaJsonRequired: '请上传 Service Account JSON', oauthSetupToken: 'OAuth / Setup Token', addMethod: '添加方式', setupTokenLongLived: 'Setup Token(长期有效)', From 28dc34b6a38b670920dc9c02819e0fd95ee33037 Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Wed, 29 Apr 2026 17:38:08 +0800 Subject: [PATCH 34/46] fix(openai): avoid inferred WS continuation on explicit tool replay --- .../internal/service/openai_ws_forwarder.go | 26 +- ...penai_ws_forwarder_ingress_session_test.go | 268 ++++++++++++++++++ .../openai_ws_forwarder_ingress_test.go | 108 ++++--- 3 files changed, 356 insertions(+), 46 deletions(-) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index dedbce1e..023217b2 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1366,16 +1366,25 @@ func 
setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string func shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled bool, turn int, - hasFunctionCallOutput bool, + signals ToolContinuationSignals, currentPreviousResponseID string, expectedPreviousResponseID string, ) bool { - if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput { return false } if strings.TrimSpace(currentPreviousResponseID) != "" { return false } + if signals.HasFunctionCallOutputMissingCallID { + return false + } + // If the client already sent tool-call context or item_reference anchors, + // treat this as a full replay / self-contained continuation payload rather + // than downgrading it into an inferred delta continuation. + if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs { + return false + } return strings.TrimSpace(expectedPreviousResponseID) != "" } @@ -3179,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( skipBeforeTurn = false currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) - hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + toolSignals := ToolContinuationSignals{ + HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(), + } + if toolSignals.HasFunctionCallOutput { + var currentReqBody map[string]any + if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil { + toolSignals = AnalyzeToolContinuationSignals(currentReqBody) + } + } + hasFunctionCallOutput := toolSignals.HasFunctionCallOutput // store=false + function_call_output 场景必须有续链锚点。 // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 if shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled, turn, - hasFunctionCallOutput, + 
toolSignals, currentPreviousResponseID, expectedPrev, ) { diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 6bf9a9ff..701f069a 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc 
:= &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-tool-context", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + 
cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 115, + Name: "openai-ingress-item-reference", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ 
:= gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String()) + + 
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { gin.SetMode(gin.TestMode) prevPreflightPingIdle := openAIWSIngressPreflightPingIdle diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index ff35cb01..5bc5db4e 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { name string storeDisabled bool turn int - hasFunctionCallOutput bool + signals ToolContinuationSignals currentPreviousResponse string expectedPrevious string want bool }{ { - name: "infer_when_all_conditions_match", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: true, + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + 
expectedPrevious: "resp_1", + want: true, }, { - name: "skip_when_store_enabled", - storeDisabled: false, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_on_first_turn", - storeDisabled: true, - turn: 1, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_without_function_call_output", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: false, - expectedPrevious: "resp_1", - want: false, + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_when_request_already_has_previous_response_id", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, currentPreviousResponse: "resp_client", expectedPrevious: "resp_1", want: false, }, { - name: "skip_when_last_turn_response_id_missing", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "", - want: false, + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "", + want: false, }, { - name: "trim_whitespace_before_judgement", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: " resp_2 ", - want: true, + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + signals: 
ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: " resp_2 ", + want: true, + }, + { + name: "skip_when_tool_call_context_already_present", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, + expectedPrevious: "resp_2", + want: false, + }, + { + name: "skip_when_item_reference_already_covers_all_call_ids", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, + expectedPrevious: "resp_2", + want: false, + }, + { + name: "skip_when_function_call_output_missing_call_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true}, + expectedPrevious: "resp_2", + want: false, }, } @@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { got := shouldInferIngressFunctionCallOutputPreviousResponseID( tt.storeDisabled, tt.turn, - tt.hasFunctionCallOutput, + tt.signals, tt.currentPreviousResponse, tt.expectedPrevious, ) From f7c13af11fa380d47635174645da9b6dd995cda3 Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Wed, 29 Apr 2026 18:02:19 +0800 Subject: [PATCH 35/46] fix: format ingress continuation test --- .../openai_ws_forwarder_ingress_test.go | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index 5bc5db4e..08597f0c 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -238,85 +238,85 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { want bool }{ { - name: "infer_when_all_conditions_match", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + 
name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, expectedPrevious: "resp_1", - want: true, + want: true, }, { - name: "skip_when_store_enabled", - storeDisabled: false, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, expectedPrevious: "resp_1", - want: false, + want: false, }, { - name: "skip_on_first_turn", - storeDisabled: true, - turn: 1, - signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, expectedPrevious: "resp_1", - want: false, + want: false, }, { - name: "skip_without_function_call_output", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{}, + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{}, expectedPrevious: "resp_1", - want: false, + want: false, }, { - name: "skip_when_request_already_has_previous_response_id", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, currentPreviousResponse: "resp_client", expectedPrevious: "resp_1", want: false, }, { - name: "skip_when_last_turn_response_id_missing", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, expectedPrevious: "", - want: false, + want: false, }, { - name: "trim_whitespace_before_judgement", - storeDisabled: true, - turn: 2, - 
signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, expectedPrevious: " resp_2 ", - want: true, + want: true, }, { - name: "skip_when_tool_call_context_already_present", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, + name: "skip_when_tool_call_context_already_present", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, expectedPrevious: "resp_2", - want: false, + want: false, }, { - name: "skip_when_item_reference_already_covers_all_call_ids", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, + name: "skip_when_item_reference_already_covers_all_call_ids", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, expectedPrevious: "resp_2", - want: false, + want: false, }, { - name: "skip_when_function_call_output_missing_call_id", - storeDisabled: true, - turn: 2, - signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true}, + name: "skip_when_function_call_output_missing_call_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true}, expectedPrevious: "resp_2", - want: false, + want: false, }, } From 7ce5b8321573e6628c4449382ab5507ac1ff5aae Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 21:00:30 +0800 Subject: [PATCH 36/46] chore: remove superpowers docs --- ...-27-account-bulk-edit-scope-and-compact.md | 359 ------------------ ...ount-bulk-edit-scope-and-compact-design.md | 233 ------------ 2 files changed, 592 deletions(-) delete mode 100644 
docs/superpowers/plans/2026-04-27-account-bulk-edit-scope-and-compact.md delete mode 100644 docs/superpowers/specs/2026-04-27-account-bulk-edit-scope-and-compact-design.md diff --git a/docs/superpowers/plans/2026-04-27-account-bulk-edit-scope-and-compact.md b/docs/superpowers/plans/2026-04-27-account-bulk-edit-scope-and-compact.md deleted file mode 100644 index 42b76664..00000000 --- a/docs/superpowers/plans/2026-04-27-account-bulk-edit-scope-and-compact.md +++ /dev/null @@ -1,359 +0,0 @@ -# Account Bulk Edit Scope And Compact Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add filter-result bulk edit to admin accounts, unify the table-level bulk-edit entry, and align OpenAI bulk-edit controls with the existing compact-related single-account settings. - -**Architecture:** Extend the existing `/admin/accounts/bulk-update` flow to accept either explicit account IDs or a server-resolved filter target. Reuse the current account-list filter contract for scope resolution, then update the accounts view and bulk-edit modal so the UI can launch either selected-account edits or current-filter-result edits from one compact dropdown. Keep the existing bulk-edit form, but expand its target contract and OpenAI-specific field coverage. - -**Tech Stack:** Vue 3, TypeScript, Vitest, Gin, Go service/repository layer, existing admin accounts API. 
- ---- - -### Task 1: Add backend test coverage for filter-target bulk update - -**Files:** -- Modify: `backend/internal/handler/admin/account_handler_mixed_channel_test.go` -- Modify: `backend/internal/service/admin_service_bulk_update_test.go` -- Test: `backend/internal/handler/admin/account_handler_mixed_channel_test.go` -- Test: `backend/internal/service/admin_service_bulk_update_test.go` - -- [ ] **Step 1: Write the failing handler test for filter-target request acceptance** - -```go -func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) { - // add a request body that omits account_ids and submits filters instead - // assert the route does not reject the request as malformed once service stubs are wired -} -``` - -- [ ] **Step 2: Run test to verify it fails** - -Run: `GOCACHE=/tmp/go-build GOMODCACHE=/tmp/go-mod go test ./backend/internal/handler/admin -run TestBulkUpdateAcceptsFilterTargetRequest -count=1` -Expected: FAIL because `BulkUpdateAccountsRequest` does not yet support `filters`. - -- [ ] **Step 3: Write the failing service test for resolving IDs from filters** - -```go -func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) { - // construct BulkUpdateAccountsInput with Filters and no AccountIDs - // stub repository list/search path to return matching IDs - // assert BulkUpdate is called with all matching account IDs -} -``` - -- [ ] **Step 4: Run test to verify it fails** - -Run: `GOCACHE=/tmp/go-build GOMODCACHE=/tmp/go-mod go test ./backend/internal/service -run TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters -count=1` -Expected: FAIL because `BulkUpdateAccountsInput` and service logic only use explicit `AccountIDs`. 
- -- [ ] **Step 5: Commit** - -```bash -git add backend/internal/handler/admin/account_handler_mixed_channel_test.go backend/internal/service/admin_service_bulk_update_test.go -git commit -m "test: cover filter-target account bulk update" -``` - -### Task 2: Implement backend filter-target bulk update - -**Files:** -- Modify: `backend/internal/handler/admin/account_handler.go` -- Modify: `backend/internal/service/admin_service.go` -- Modify: `backend/internal/repository/account_repo.go` -- Modify: `backend/internal/service/account_service.go` -- Test: `backend/internal/handler/admin/account_handler_mixed_channel_test.go` -- Test: `backend/internal/service/admin_service_bulk_update_test.go` - -- [ ] **Step 1: Implement request structs and validation for filter targets** - -```go -type BulkUpdateAccountFilters struct { - Platform string `json:"platform"` - Type string `json:"type"` - Status string `json:"status"` - Group string `json:"group"` - Search string `json:"search"` - PrivacyMode string `json:"privacy_mode"` -} - -type BulkUpdateAccountsRequest struct { - AccountIDs []int64 `json:"account_ids"` - Filters *BulkUpdateAccountFilters `json:"filters"` - // existing fields remain unchanged -} -``` - -- [ ] **Step 2: Resolve filter targets in the service layer with one canonical path** - -```go -type BulkUpdateAccountsInput struct { - AccountIDs []int64 - Filters *BulkUpdateAccountFilters - // existing fields remain unchanged -} - -if len(input.AccountIDs) == 0 && input.Filters != nil { - ids, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters) - if err != nil { - return nil, err - } - input.AccountIDs = ids -} -``` - -- [ ] **Step 3: Reuse existing account-search/repository logic to resolve all matching IDs** - -```go -func (s *AdminService) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) { - // call the existing repository list/search path with the submitted filters - // page through all matching rows or use 
a dedicated ID-only query helper - // return unique IDs in stable order -} -``` - -- [ ] **Step 4: Run targeted backend tests** - -Run: `GOCACHE=/tmp/go-build GOMODCACHE=/tmp/go-mod go test ./backend/internal/handler/admin ./backend/internal/service -run 'TestBulkUpdateAcceptsFilterTargetRequest|TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters' -count=1` -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add backend/internal/handler/admin/account_handler.go backend/internal/service/admin_service.go backend/internal/repository/account_repo.go backend/internal/service/account_service.go backend/internal/handler/admin/account_handler_mixed_channel_test.go backend/internal/service/admin_service_bulk_update_test.go -git commit -m "feat: support filter-target account bulk update" -``` - -### Task 3: Add frontend API and modal tests for target scope - -**Files:** -- Modify: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` -- Create: `frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` -- Modify: `frontend/src/api/admin/accounts.ts` -- Test: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` -- Test: `frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` - -- [ ] **Step 1: Write the failing modal test for filter-target payload submission** - -```ts -it('submits bulk edit using current filters when target mode is filtered-results', async () => { - // mount BulkEditAccountModal with targetMode='filtered' - // submit a minimal change - // expect adminAPI.accounts.bulkUpdate to receive { filters: ... } rather than account_ids -}) -``` - -- [ ] **Step 2: Run test to verify it fails** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts -t "filtered-results"` -Expected: FAIL because the modal only accepts `accountIds`. 
- -- [ ] **Step 3: Write the failing accounts-view test for dropdown launch actions** - -```ts -it('opens bulk edit for current filtered results from the table action dropdown', async () => { - // mount AccountsView with filters set - // click Bulk edit > current filtered results - // assert modal props contain filter target metadata -}) -``` - -- [ ] **Step 4: Run test to verify it fails** - -Run: `pnpm -C frontend test:run src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` -Expected: FAIL because the dropdown action and target scope state do not exist yet. - -- [ ] **Step 5: Commit** - -```bash -git add frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts frontend/src/api/admin/accounts.ts -git commit -m "test: cover account bulk edit target scopes" -``` - -### Task 4: Implement unified frontend bulk-edit target scope flow - -**Files:** -- Modify: `frontend/src/views/admin/AccountsView.vue` -- Modify: `frontend/src/components/admin/account/AccountBulkActionsBar.vue` -- Modify: `frontend/src/components/account/BulkEditAccountModal.vue` -- Modify: `frontend/src/api/admin/accounts.ts` -- Modify: `frontend/src/i18n/locales/zh.ts` -- Modify: `frontend/src/i18n/locales/en.ts` -- Test: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` -- Test: `frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` - -- [ ] **Step 1: Add a typed frontend target contract for bulk edit** - -```ts -export type AccountBulkEditTarget = - | { mode: 'selected'; accountIds: number[]; selectedPlatforms: AccountPlatform[]; selectedTypes: AccountType[] } - | { mode: 'filtered'; filters: AccountListFilters; previewCount: number; selectedPlatforms: AccountPlatform[]; selectedTypes: AccountType[] } -``` - -- [ ] **Step 2: Replace the single selected-row edit button with one dropdown** - -```vue - -``` - -- [ ] **Step 3: Snapshot current filters and preview count when launching 
filtered mode** - -```ts -const openBulkEditFiltered = async () => { - const filters = toBulkEditFilterSnapshot(params) - const preview = await adminAPI.accounts.list(1, 1, filters) - bulkEditTarget.value = { - mode: 'filtered', - filters, - previewCount: preview.pagination.total, - selectedPlatforms: collectPlatforms(preview.data), - selectedTypes: collectTypes(preview.data) - } - showBulkEdit.value = true -} -``` - -- [ ] **Step 4: Update modal submission to call `bulkUpdate` with either `account_ids` or `filters`** - -```ts -if (props.target.mode === 'selected') { - await adminAPI.accounts.bulkUpdate({ account_ids: props.target.accountIds, ...updates }) -} else { - await adminAPI.accounts.bulkUpdate({ filters: props.target.filters, ...updates }) -} -``` - -- [ ] **Step 5: Run targeted frontend tests** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` -Expected: PASS - -- [ ] **Step 6: Commit** - -```bash -git add frontend/src/views/admin/AccountsView.vue frontend/src/components/admin/account/AccountBulkActionsBar.vue frontend/src/components/account/BulkEditAccountModal.vue frontend/src/api/admin/accounts.ts frontend/src/i18n/locales/zh.ts frontend/src/i18n/locales/en.ts frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts -git commit -m "feat: add filtered-result account bulk edit" -``` - -### Task 5: Add failing tests for missing OpenAI bulk-edit fields - -**Files:** -- Modify: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` -- Test: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` - -- [ ] **Step 1: Write the failing OAuth test for `codex_cli_only`** - -```ts -it('OpenAI OAuth bulk edit can submit codex_cli_only', async () => { - // enable the toggle and submit - // expect extra.codex_cli_only to be sent -}) -``` - -- [ ] **Step 
2: Run test to verify it fails** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts -t "codex_cli_only"` -Expected: FAIL because the modal has no such control or payload mapping. - -- [ ] **Step 3: Write the failing API key test for API key WS mode** - -```ts -it('OpenAI API key bulk edit submits API key WS mode fields', async () => { - // enable the API key WS mode selector and submit - // expect openai_apikey_responses_websockets_v2_mode and enabled flag -}) -``` - -- [ ] **Step 4: Run test to verify it fails** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts -t "API key WS mode"` -Expected: FAIL because the modal only submits OAuth WS mode. - -- [ ] **Step 5: Commit** - -```bash -git add frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts -git commit -m "test: cover missing OpenAI bulk edit fields" -``` - -### Task 6: Implement missing OpenAI bulk-edit controls and payload wiring - -**Files:** -- Modify: `frontend/src/components/account/BulkEditAccountModal.vue` -- Modify: `frontend/src/i18n/locales/zh.ts` -- Modify: `frontend/src/i18n/locales/en.ts` -- Test: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` - -- [ ] **Step 1: Add UI controls for OAuth `codex_cli_only` and API key WS mode** - -```vue -
- - -
- -
- -
-``` - -- [ ] **Step 2: Mirror single-account payload semantics in the bulk-edit submit builder** - -```ts -if (enableCodexCLIOnly.value) { - const extra = ensureExtra() - extra.codex_cli_only = codexCLIOnlyEnabled.value -} - -if (enableOpenAIAPIKeyWSMode.value) { - const extra = ensureExtra() - extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) -} -``` - -- [ ] **Step 3: Run focused modal tests** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts` -Expected: PASS - -- [ ] **Step 4: Commit** - -```bash -git add frontend/src/components/account/BulkEditAccountModal.vue frontend/src/i18n/locales/zh.ts frontend/src/i18n/locales/en.ts frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts -git commit -m "feat: align OpenAI bulk edit compact settings" -``` - -### Task 7: Final regression verification - -**Files:** -- Modify: none expected -- Test: `frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts` -- Test: `frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` -- Test: `backend/internal/handler/admin/account_handler_mixed_channel_test.go` -- Test: `backend/internal/service/admin_service_bulk_update_test.go` - -- [ ] **Step 1: Run frontend typecheck** - -Run: `pnpm -C frontend typecheck` -Expected: PASS - -- [ ] **Step 2: Run focused frontend test suite** - -Run: `pnpm -C frontend test:run src/components/account/__tests__/BulkEditAccountModal.spec.ts src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts` -Expected: PASS - -- [ ] **Step 3: Run focused backend test suite** - -Run: `GOCACHE=/tmp/go-build GOMODCACHE=/tmp/go-mod go test ./backend/internal/handler/admin ./backend/internal/service -run 'BulkUpdate|bulk update' -count=1` -Expected: PASS - -- [ ] **Step 4: Commit final integration fixes if needed** - 
-```bash -git add frontend/src/components/account/BulkEditAccountModal.vue frontend/src/views/admin/AccountsView.vue frontend/src/components/admin/account/AccountBulkActionsBar.vue frontend/src/api/admin/accounts.ts frontend/src/i18n/locales/zh.ts frontend/src/i18n/locales/en.ts backend/internal/handler/admin/account_handler.go backend/internal/service/admin_service.go backend/internal/repository/account_repo.go backend/internal/service/account_service.go frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts frontend/src/views/admin/__tests__/AccountsView.bulkEdit.spec.ts backend/internal/handler/admin/account_handler_mixed_channel_test.go backend/internal/service/admin_service_bulk_update_test.go -git commit -m "feat: finish account bulk edit scope and compact support" -``` diff --git a/docs/superpowers/specs/2026-04-27-account-bulk-edit-scope-and-compact-design.md b/docs/superpowers/specs/2026-04-27-account-bulk-edit-scope-and-compact-design.md deleted file mode 100644 index 3a1dc5ac..00000000 --- a/docs/superpowers/specs/2026-04-27-account-bulk-edit-scope-and-compact-design.md +++ /dev/null @@ -1,233 +0,0 @@ -# Account Bulk Edit Scope And Compact Design - -## Summary - -This change expands admin account bulk edit in two directions: - -1. Add a second bulk-edit target scope based on the current filter result set, so operators do not need to manually select every account. -2. Align OpenAI bulk-edit fields with single-account create/edit for the compact-related settings that are already supported elsewhere. - -The design keeps the existing selected-row workflow intact and adds a unified bulk-edit entry with two explicit actions: - -- `Bulk edit selected accounts` -- `Bulk edit current filtered results` - -`Current filtered results` reuses the existing account-list filters. 
That means: - -- with no filters, it targets the whole account inventory -- with a group filter, it targets all accounts in that group -- with combined filters, it targets all matching accounts - -## Goals - -- Preserve the current selected-account bulk edit flow. -- Let operators bulk edit the full current filtered result set without manual row selection. -- Show the user the exact target scope before applying changes. -- Reuse the current list filter semantics instead of inventing a separate "all accounts" or "by group" API. -- Add the missing OpenAI bulk-edit fields: - - OAuth `codex_cli_only` - - API key `openai_apikey_responses_websockets_v2_mode` - -## Non-Goals - -- No new standalone "edit all accounts" route that ignores filters. -- No new dedicated "edit group" route separate from list filters. -- No change to the backend merge semantics for other bulk-edit fields. -- No attempt in this change to refactor all account form components into a shared schema system. - -## Current State - -### Bulk edit entry - -The account list currently exposes bulk edit only through selected-row actions. `AccountsView.vue` passes `selIds`, `selPlatforms`, and `selTypes` into `BulkEditAccountModal.vue`. - -### Filter state - -The account page already keeps a central `params` object for current filters and reloads the table from that state. Group filtering already exists in `AccountTableFilters.vue`. - -### Bulk edit payload - -`BulkEditAccountModal.vue` builds a bulk update request around explicit account IDs. - -### OpenAI field gap - -Single-account create/edit already supports: - -- `openai_passthrough` -- OAuth WS mode -- API key WS mode -- OAuth `codex_cli_only` - -Bulk edit currently supports: - -- `openai_passthrough` -- OAuth WS mode only - -That leaves a real capability gap for operators managing large OpenAI account sets. - -## User Experience - -### Entry point - -Use one compact `Bulk edit` dropdown button in the table-level bulk actions area above the grid. 
- -The dropdown contains: - -- `Bulk edit selected accounts` -- `Bulk edit current filtered results` - -Behavior: - -- If there is no row selection, the `selected accounts` action is disabled. -- `Current filtered results` is always available. -- The existing separate immediate `Edit` action in the selected-row bar is replaced by this unified dropdown to avoid duplicate buttons that mean different scopes. - -### Modal scope messaging - -The bulk edit modal gets a required scope descriptor prop. - -For `selected accounts`: - -- show the existing count-based info banner -- keep using explicit selected account metadata for platform/type compatibility checks - -For `current filtered results`: - -- show a banner stating that edits apply to the current filtered result set -- show the matched account count from a preview query -- show a short summary of active filters when practical, especially group/search/platform/type/status filters - -### Safety - -For filtered-result mode: - -- disable submit if the preview count is `0` -- refresh the target count when the modal opens -- keep the final success toast count aligned with the backend result - -The modal should not silently fall back from filtered mode to selected mode. - -## Backend/API Design - -### Request model - -Extend bulk update to support two target modes: - -- explicit IDs -- filter-based query - -The request shape should keep backward compatibility for the selected-ID path while allowing a filter target. The backend handler can accept a payload that contains either: - -- `account_ids` -- or `filters` - -but not neither. 
- -The `filters` payload should reuse the existing account-list query semantics already used by `/admin/accounts` and `/admin/accounts/data`, including: - -- `search` -- `platform` -- `type` -- `status` -- `privacy_mode` -- `group` -- existing sort fields may be ignored for mutation targeting if not needed - -### Preview count - -The frontend needs an accurate target count before submit in filtered-result mode. The simplest compatible approach is: - -- call the existing account list endpoint with the current filters and a minimal page size strategy sufficient to obtain total count - -If the current API makes that awkward, add a narrow preview/count helper for bulk edit target resolution. Prefer reusing the existing listing contract first. - -### Target resolution - -For filtered-result mode, the backend must resolve matching account IDs server-side from the submitted filters rather than trusting only currently loaded page data. This is required so filtered-result mode can act on the full result set across pagination. - -### Compatibility metadata - -The frontend still needs platform/type compatibility to determine which fields to show. For filtered-result mode, derive this from the preview result set returned from the same query used to show count. If the preview spans mixed incompatible account types, show the same warnings/conditional UI that selected mode already uses. - -## Frontend Design - -### Accounts view - -`AccountsView.vue` will: - -- replace the direct selected-only bulk edit trigger with a dropdown action model -- keep a reactive description of the pending bulk edit scope -- pass either selected IDs or current filter params into the modal - -The "current filtered results" action uses the live `params` object snapshot at open time, not a mutable live subscription while the modal is already open. 
- -### Bulk edit modal - -`BulkEditAccountModal.vue` will accept a richer target contract, for example: - -- target mode -- selected IDs or filter snapshot -- preview count -- preview platform/type coverage if needed - -The modal remains one form; only the scope banner and submission target differ. - -### OpenAI field alignment - -Add the missing OpenAI controls to bulk edit: - -- OAuth `codex_cli_only` -- API key WS mode selector - -Rules: - -- OAuth accounts show OAuth WS mode and `codex_cli_only` -- API key accounts show API key WS mode -- mixed OpenAI OAuth/API key selections continue to show only fields that are safe for the entire target set - -The payload builder must write: - -- `extra.codex_cli_only` -- `extra.openai_apikey_responses_websockets_v2_mode` -- `extra.openai_apikey_responses_websockets_v2_enabled` - -with the same enable/disable semantics already used by single-account forms. - -## Testing Strategy - -### Frontend tests - -Add or extend tests for: - -- bulk edit dropdown actions in the accounts view -- selected-account mode still calling bulk update by IDs -- filtered-result mode calling bulk update with filter target -- filtered-result mode showing preview count and blocking submit on zero matches -- OAuth bulk edit supporting `codex_cli_only` -- API key bulk edit supporting API key WS mode -- no regression for existing passthrough and OAuth WS mode tests - -### Backend tests - -Add or extend tests for: - -- bulk update request validation for IDs vs filters -- filtered-result mode resolving all matching accounts across pagination semantics -- mixed-channel risk checks still running for filter-target updates if applicable -- backward compatibility for the existing selected-ID request path - -## Risks - -- Filter semantics can drift if bulk edit reimplements list-filter parsing differently from the listing endpoints. -- Filtered-result mode can surprise users if the active scope is not shown clearly enough. 
-- Large filtered updates may affect many rows; success/error messaging must stay explicit. - -## Recommendation - -Implement this as a targeted extension of the existing bulk edit flow: - -- unify the entry point in the table action area -- add filter-target bulk update support -- align the missing OpenAI compact-related fields - -This keeps the mental model simple and solves the large-account-management pain without introducing a second parallel batch-edit system. From 5e54d492be59fb4254427704c68cb39d2fbe1616 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 21:35:18 +0800 Subject: [PATCH 37/46] fix(lint): check type assertion error in codex transform test The errcheck linter flagged an unchecked type assertion on item["type"].(string). Use the two-value form with require.True to satisfy the linter and fail clearly on unexpected types. --- backend/internal/service/openai_codex_transform_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index c6f147d8..7ab6bfc0 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1176,7 +1176,9 @@ func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *t for _, raw := range filtered { item, ok := raw.(map[string]any) require.True(t, ok) - gotTypes[item["type"].(string)]++ + typ, ok := item["type"].(string) + require.True(t, ok) + gotTypes[typ]++ } require.Equal(t, 1, gotTypes["message"]) require.Equal(t, 1, gotTypes["function_call"]) From 40feb86ba4657d5cb4ef3077d2a4a38abbd8b395 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 22:11:45 +0800 Subject: [PATCH 38/46] fix(httputil): add decompression bomb guard and fix errcheck lint --- backend/internal/pkg/httputil/body.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git 
a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go index 31bba8c5..cee12948 100644 --- a/backend/internal/pkg/httputil/body.go +++ b/backend/internal/pkg/httputil/body.go @@ -16,6 +16,9 @@ import ( const ( requestBodyReadInitCap = 512 requestBodyReadMaxInitCap = 1 << 20 + // maxDecompressedBodySize limits the decompressed request body to 64 MB + // to prevent decompression bomb attacks. + maxDecompressedBodySize = 64 << 20 ) // ReadRequestBodyWithPrealloc reads request body with preallocated buffer based @@ -69,21 +72,21 @@ func decompressRequestBody(encoding string, raw []byte) ([]byte, error) { return nil, err } defer dec.Close() - return io.ReadAll(dec) + return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize)) case "gzip", "x-gzip": gr, err := gzip.NewReader(bytes.NewReader(raw)) if err != nil { return nil, err } - defer gr.Close() - return io.ReadAll(gr) + defer func() { _ = gr.Close() }() + return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize)) case "deflate": zr, err := zlib.NewReader(bytes.NewReader(raw)) if err != nil { return nil, err } - defer zr.Close() - return io.ReadAll(zr) + defer func() { _ = zr.Close() }() + return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize)) default: return nil, errors.New("unsupported Content-Encoding") } From 8bf2a7b88a14188acf34b6e0693d81dfc42559b7 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 29 Apr 2026 22:48:39 +0800 Subject: [PATCH 39/46] fix(scheduler): resolve SetSnapshot race conditions and remove usage throttle Backend: Fix three race conditions in SetSnapshot that caused account scheduling anomalies and broken sticky sessions: - Use Lua CAS script for atomic version activation, preventing version rollback when concurrent goroutines write snapshots simultaneously - Add UnlockBucket to release rebuild lock immediately after completion instead of waiting 30s TTL expiry - Replace immediate DEL of old snapshots with 60s EXPIRE grace period, preventing readers from 
hitting empty ZRANGE during version switches Frontend: Remove serial queue throttle (1-2s delay per request) from usage loading since backend now uses passive sampling. All usage requests execute immediately in parallel. --- ...eway_handler_warmup_intercept_unit_test.go | 3 + .../account_repo_integration_test.go | 4 + .../internal/repository/scheduler_cache.go | 83 +++++++++++++++--- backend/internal/service/scheduler_cache.go | 2 + .../scheduler_snapshot_hydration_test.go | 4 + .../service/scheduler_snapshot_service.go | 3 + frontend/src/utils/usageLoadQueue.ts | 87 ++----------------- 7 files changed, 91 insertions(+), 95 deletions(-) diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 71030140..57554cf9 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time. 
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { return true, nil } +func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error { + return nil +} func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index b249bb61..d1cea9eb 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi return true, nil } +func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error { + return nil +} + func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index add0e501..8e1f9f56 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -24,6 +24,49 @@ const ( defaultSchedulerSnapshotMGetChunkSize = 128 defaultSchedulerSnapshotWriteChunkSize = 256 + + // snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。 + // 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。 + snapshotGraceTTLSeconds = 60 +) + +var ( + // activateSnapshotScript 原子 CAS 切换快照版本。 + // 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。 + // 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。 + // + // KEYS[1] = activeKey (sched:active:{bucket}) + // KEYS[2] = readyKey (sched:ready:{bucket}) + // KEYS[3] = bucketSetKey (sched:buckets) + // KEYS[4] = snapshotKey (新写入的快照 key) + // ARGV[1] = 新版本号字符串 + // ARGV[2] = bucket 字符串 (用于 SADD) + // ARGV[3] = 快照 key 前缀 (用于构造旧快照 key) + // ARGV[4] = 宽限期 TTL 秒数 + // + // 返回 1 = 已激活, 0 
= 版本过旧未激活 + activateSnapshotScript = redis.NewScript(` +local currentActive = redis.call('GET', KEYS[1]) +local newVersion = tonumber(ARGV[1]) + +if currentActive ~= false then + local curVersion = tonumber(currentActive) + if curVersion and newVersion < curVersion then + redis.call('DEL', KEYS[4]) + return 0 + end +end + +redis.call('SET', KEYS[1], ARGV[1]) +redis.call('SET', KEYS[2], '1') +redis.call('SADD', KEYS[3], ARGV[2]) + +if currentActive ~= false and currentActive ~= ARGV[1] then + redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4])) +end + +return 1 +`) ) type schedulerCache struct { @@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul } func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error { - activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) - oldActive, _ := c.rdb.Get(ctx, activeKey).Result() - + // Phase 1: 分配新版本号并写入快照数据。 + // INCR 保证每个调用方获得唯一递增版本号。 + // 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。 versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket) version, err := c.rdb.Incr(ctx, versionKey).Result() if err != nil { @@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul return err } - pipe := c.rdb.Pipeline() if len(accounts) > 0 { // 使用序号作为 score,保持数据库返回的排序语义。 members := make([]redis.Z, 0, len(accounts)) @@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul Member: strconv.FormatInt(account.ID, 10), }) } + pipe := c.rdb.Pipeline() for start := 0; start < len(members); start += c.writeChunkSize { end := start + c.writeChunkSize if end > len(members) { @@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul } pipe.ZAdd(ctx, snapshotKey, members[start:end]...) 
} - } else { - pipe.Del(ctx, snapshotKey) - } - pipe.Set(ctx, activeKey, versionStr, 0) - pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0) - pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String()) - if _, err := pipe.Exec(ctx); err != nil { - return err + if _, err := pipe.Exec(ctx); err != nil { + return err + } } - if oldActive != "" && oldActive != versionStr { - _ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err() + // Phase 2: 原子 CAS 激活版本。 + // Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针, + // 防止并发写入导致版本回滚。 + // 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。 + activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) + readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket) + snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode) + + keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey} + args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds} + + _, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result() + if err != nil { + return err } return nil @@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result() } +func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error { + key := schedulerBucketKey(schedulerLockPrefix, bucket) + return c.rdb.Del(ctx, key).Err() +} + func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result() if err != nil { diff --git a/backend/internal/service/scheduler_cache.go b/backend/internal/service/scheduler_cache.go index f36135e0..f9794c82 100644 --- a/backend/internal/service/scheduler_cache.go +++ b/backend/internal/service/scheduler_cache.go @@ -59,6 +59,8 @@ type SchedulerCache interface { UpdateLastUsed(ctx context.Context, updates 
map[int64]time.Time) error // TryLockBucket 尝试获取分桶重建锁。 TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) + // UnlockBucket 释放分桶重建锁。 + UnlockBucket(ctx context.Context, bucket SchedulerBucket) error // ListBuckets 返回已注册的分桶集合。 ListBuckets(ctx context.Context) ([]SchedulerBucket, error) // GetOutboxWatermark 读取 outbox 水位。 diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go index 5c0b289b..0b32c2ad 100644 --- a/backend/internal/service/scheduler_snapshot_hydration_test.go +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched return true, nil } +func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error { + return nil +} + func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 62b6993d..a68cdf0c 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch if !ok { return nil } + defer func() { + _ = s.cache.UnlockBucket(ctx, bucket) + }() rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/frontend/src/utils/usageLoadQueue.ts b/frontend/src/utils/usageLoadQueue.ts index 7bea5679..042b1240 100644 --- a/frontend/src/utils/usageLoadQueue.ts +++ b/frontend/src/utils/usageLoadQueue.ts @@ -1,93 +1,18 @@ /** - * Usage request scheduler — throttles Anthropic API calls by proxy exit. + * Usage request scheduler. 
* - * Anthropic OAuth/setup-token accounts sharing the same proxy exit are placed - * into a serial queue with a random 1–2s delay between requests, preventing - * upstream 429 rate-limit errors. - * - * Proxy identity = host:port:username — two proxy records pointing to the - * same exit share a single queue. Accounts without a proxy go into a - * "direct" queue. - * - * All other platforms bypass the queue and execute immediately. + * All platforms execute immediately without queuing — the backend uses + * passive sampling so upstream 429 rate-limit errors are no longer a concern. */ import type { Account } from '@/types' -const GROUP_DELAY_MIN_MS = 1000 -const GROUP_DELAY_MAX_MS = 2000 - -type Task = { - fn: () => Promise - resolve: (value: T) => void - reject: (reason: unknown) => void -} - -const queues = new Map[]>() -const running = new Set() - -/** Whether this account needs throttled queuing. */ -function needsThrottle(account: Account): boolean { - return ( - account.platform === 'anthropic' && - (account.type === 'oauth' || account.type === 'setup-token') - ) -} - -/** Build a queue key from proxy connection details. */ -function buildGroupKey(account: Account): string { - const proxy = account.proxy - const proxyIdentity = proxy - ? `${proxy.host}:${proxy.port}:${proxy.username || ''}` - : 'direct' - return `anthropic:${proxyIdentity}` -} - -async function drain(groupKey: string) { - if (running.has(groupKey)) return - running.add(groupKey) - - const queue = queues.get(groupKey) - while (queue && queue.length > 0) { - const task = queue.shift()! - try { - const result = await task.fn() - task.resolve(result) - } catch (err) { - task.reject(err) - } - if (queue.length > 0) { - const jitter = GROUP_DELAY_MIN_MS + Math.random() * (GROUP_DELAY_MAX_MS - GROUP_DELAY_MIN_MS) - await new Promise((r) => setTimeout(r, jitter)) - } - } - - running.delete(groupKey) - queues.delete(groupKey) -} - /** - * Schedule a usage fetch. 
Anthropic accounts are queued by proxy exit; - * all other platforms execute immediately. + * Schedule a usage fetch. All requests execute immediately. */ export function enqueueUsageRequest( - account: Account, + _account: Account, fn: () => Promise ): Promise { - // Non-Anthropic → fire immediately, no queuing - if (!needsThrottle(account)) { - return fn() - } - - const key = buildGroupKey(account) - - return new Promise((resolve, reject) => { - let queue = queues.get(key) - if (!queue) { - queue = [] - queues.set(key, queue) - } - queue.push({ fn, resolve, reject } as Task) - drain(key) - }) + return fn() } From 8ad099baa6057f0dfed32ded1f04fc5ea5a38717 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:08:59 +0000 Subject: [PATCH 40/46] chore: sync VERSION to 0.1.120 [skip ci] --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 841597f0..27f3bc3e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.119 +0.1.120 From f084d30d6530cc76d76c9ff6a5cd20bb6628988e Mon Sep 17 00:00:00 2001 From: DaydreamCoding Date: Thu, 16 Apr 2026 21:23:19 +0800 Subject: [PATCH 41/46] =?UTF-8?q?fix:=20=E6=81=A2=E5=A4=8D=E8=A1=A8?= =?UTF-8?q?=E6=A0=BC=E5=88=86=E9=A1=B5=E5=A4=A7=E5=B0=8F=20localStorage=20?= =?UTF-8?q?=E6=8C=81=E4=B9=85=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - usePersistedPageSize: 恢复 localStorage 读写,以系统配置为 fallback - useTableLoader: handlePageSizeChange 时写入 localStorage - Pagination.vue: handlePageSizeChange 时写入 localStorage Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/src/components/common/Pagination.vue | 2 ++ .../src/composables/usePersistedPageSize.ts | 28 ++++++++++++++++--- frontend/src/composables/useTableLoader.ts | 3 +- 3 files changed, 28 insertions(+), 5 
deletions(-) diff --git a/frontend/src/components/common/Pagination.vue b/frontend/src/components/common/Pagination.vue index 2bfc6872..9b4ac200 100644 --- a/frontend/src/components/common/Pagination.vue +++ b/frontend/src/components/common/Pagination.vue @@ -123,6 +123,7 @@ import { useI18n } from 'vue-i18n' import Icon from '@/components/icons/Icon.vue' import Select from './Select.vue' import { getConfiguredTablePageSizeOptions, normalizeTablePageSize } from '@/utils/tablePreferences' +import { setPersistedPageSize } from '@/composables/usePersistedPageSize' const { t } = useI18n() @@ -224,6 +225,7 @@ const goToPage = (newPage: number) => { const handlePageSizeChange = (value: string | number | boolean | null) => { if (value === null || typeof value === 'boolean') return const newPageSize = normalizeTablePageSize(typeof value === 'string' ? parseInt(value, 10) : value) + setPersistedPageSize(newPageSize) emit('update:pageSize', newPageSize) } diff --git a/frontend/src/composables/usePersistedPageSize.ts b/frontend/src/composables/usePersistedPageSize.ts index 366619ea..972373d1 100644 --- a/frontend/src/composables/usePersistedPageSize.ts +++ b/frontend/src/composables/usePersistedPageSize.ts @@ -1,9 +1,29 @@ import { getConfiguredTableDefaultPageSize, normalizeTablePageSize } from '@/utils/tablePreferences' -/** - * 读取当前系统配置的表格默认每页条数。 - * 不再使用本地持久化缓存,所有页面统一以通用表格设置为准。 - */ +const STORAGE_KEY = 'table-page-size' + export function getPersistedPageSize(fallback = getConfiguredTableDefaultPageSize()): number { + if (typeof window !== 'undefined') { + try { + const stored = window.localStorage.getItem(STORAGE_KEY) + if (stored !== null) { + const parsed = Number(stored) + if (Number.isFinite(parsed)) { + return normalizeTablePageSize(parsed) + } + } + } catch (error) { + console.warn('Failed to read persisted page size:', error) + } + } return normalizeTablePageSize(getConfiguredTableDefaultPageSize() || fallback) } + +export function setPersistedPageSize(size: 
number): void { + if (typeof window === 'undefined') return + try { + window.localStorage.setItem(STORAGE_KEY, String(size)) + } catch (error) { + console.warn('Failed to persist page size:', error) + } +} diff --git a/frontend/src/composables/useTableLoader.ts b/frontend/src/composables/useTableLoader.ts index c288f42e..67c1dcdb 100644 --- a/frontend/src/composables/useTableLoader.ts +++ b/frontend/src/composables/useTableLoader.ts @@ -1,7 +1,7 @@ import { ref, reactive, onUnmounted, toRaw } from 'vue' import { useDebounceFn } from '@vueuse/core' import type { BasePaginationResponse, FetchOptions } from '@/types' -import { getPersistedPageSize } from './usePersistedPageSize' +import { getPersistedPageSize, setPersistedPageSize } from './usePersistedPageSize' interface PaginationState { page: number @@ -88,6 +88,7 @@ export function useTableLoader>(options: TableL const handlePageSizeChange = (size: number) => { pagination.page_size = size pagination.page = 1 + setPersistedPageSize(size) load() } From 733627cf9d3fe337f24f8f658979646a52f3a5ba Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 30 Apr 2026 11:38:11 +0800 Subject: [PATCH 42/46] fix: improve sticky session scheduling --- backend/internal/handler/gateway_handler.go | 44 ++++++ .../internal/repository/scheduler_cache.go | 58 +++++++ .../scheduler_cache_integration_test.go | 16 ++ .../repository/scheduler_cache_unit_test.go | 40 +++++ backend/internal/service/gateway_service.go | 142 ++++++++++++++++-- 5 files changed, 290 insertions(+), 10 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ef532559..7b082b07 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -262,6 +262,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + // [DEBUG-STICKY] 打印会话 hash 生成结果 + reqLog.Info("sticky.session_hash_generated", + 
zap.String("session_hash", sessionHash), + zap.String("metadata_user_id_raw", parsedReq.MetadataUserID), + ) + // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 platform := "" if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { @@ -278,6 +284,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + // [DEBUG-STICKY] 打印粘性会话查询结果 + reqLog.Info("sticky.cache_lookup", + zap.String("session_key", sessionKey), + zap.Int64("bound_account_id", sessionBoundAccountID), + ) if sessionBoundAccountID > 0 { prefetchedGroupID := int64(0) if apiKey.GroupID != nil { @@ -286,6 +297,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } + } else { + reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash)) } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 @@ -536,6 +549,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 + reqLog.Info("sticky.selecting_account", + zap.String("session_key", sessionKey), + zap.Int64("sticky_bound_account_id", sessionBoundAccountID), + zap.Bool("has_bound_session", hasBoundSession), + zap.Int("failed_account_count", len(fs.FailedAccountIDs)), + ) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { @@ -569,6 +588,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) + // [DEBUG-STICKY] 打印账号选择结果 
+ reqLog.Info("sticky.account_selected", + zap.Int64("selected_account_id", account.ID), + zap.String("account_name", account.Name), + zap.Bool("slot_acquired", selection.Acquired), + zap.Bool("has_wait_plan", selection.WaitPlan != nil), + zap.Int64("sticky_bound_account_id", sessionBoundAccountID), + zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID), + ) + // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) @@ -635,6 +664,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // Slot acquired: no longer waiting in queue. releaseWait() + reqLog.Info("sticky.bind_after_wait", + zap.String("session_key", sessionKey), + zap.Int64("account_id", account.ID), + ) if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } @@ -829,6 +862,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } + // 绑定粘性会话(成功转发后绑定/刷新) + // - 无现有绑定(首次请求):创建绑定 + // - 选中账号与粘性账号一致:刷新 TTL + // - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定, + // 下次请求粘性账号恢复后仍可命中 + if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) { + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index 8e1f9f56..590ddaa3 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -449,11 
+449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account { SessionWindowStart: account.SessionWindowStart, SessionWindowEnd: account.SessionWindowEnd, SessionWindowStatus: account.SessionWindowStatus, + AccountGroups: filterSchedulerAccountGroups(account.AccountGroups), + GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups), Credentials: filterSchedulerCredentials(account.Credentials), Extra: filterSchedulerExtra(account.Extra), } } +func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup { + if len(accountGroups) == 0 { + return nil + } + + filtered := make([]service.AccountGroup, 0, len(accountGroups)) + for _, ag := range accountGroups { + if ag.GroupID <= 0 { + continue + } + filtered = append(filtered, service.AccountGroup{ + AccountID: ag.AccountID, + GroupID: ag.GroupID, + Priority: ag.Priority, + CreatedAt: ag.CreatedAt, + }) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 { + if len(groupIDs) == 0 && len(accountGroups) == 0 { + return nil + } + + seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups)) + filtered := make([]int64, 0, len(groupIDs)+len(accountGroups)) + for _, id := range groupIDs { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + filtered = append(filtered, id) + } + for _, ag := range accountGroups { + if ag.GroupID <= 0 { + continue + } + if _, ok := seen[ag.GroupID]; ok { + continue + } + seen[ag.GroupID] = struct{}{} + filtered = append(filtered, ag.GroupID) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + func filterSchedulerCredentials(credentials map[string]any) map[string]any { if len(credentials) == 0 { return nil diff --git a/backend/internal/repository/scheduler_cache_integration_test.go 
b/backend/internal/repository/scheduler_cache_integration_test.go index 134a6a07..948c2c73 100644 --- a/backend/internal/repository/scheduler_cache_integration_test.go +++ b/backend/internal/repository/scheduler_cache_integration_test.go @@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) SessionWindowStart: &now, SessionWindowEnd: &windowEnd, SessionWindowStatus: "active", + GroupIDs: []int64{bucket.GroupID}, + AccountGroups: []service.AccountGroup{ + { + AccountID: 101, + GroupID: bucket.GroupID, + Priority: 5, + Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"}, + }, + }, } require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account})) @@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) require.Equal(t, 4, got.GetMaxSessions()) require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes()) require.Nil(t, got.Extra["unused_large_field"]) + require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs) + require.Len(t, got.AccountGroups, 1) + require.Equal(t, account.ID, got.AccountGroups[0].AccountID) + require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID) + require.Nil(t, got.AccountGroups[0].Group) full, err := cache.GetAccount(ctx, account.ID) require.NoError(t, err) require.NotNil(t, full) require.Equal(t, "secret-access-token", full.GetCredential("access_token")) require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob")) + require.Len(t, full.AccountGroups, 1) + require.NotNil(t, full.AccountGroups[0].Group) } diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go index bcfd0e7a..33f3b581 100644 --- a/backend/internal/repository/scheduler_cache_unit_test.go +++ b/backend/internal/repository/scheduler_cache_unit_test.go @@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) { require.Equal(t, true, 
got.Extra["mixed_scheduling"]) require.Nil(t, got.Extra["unused_large_field"]) } + +func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) { + account := service.Account{ + ID: 42, + Platform: service.PlatformAnthropic, + GroupIDs: []int64{7, 9, 7, 0}, + AccountGroups: []service.AccountGroup{ + { + AccountID: 42, + GroupID: 7, + Priority: 2, + Account: &service.Account{ID: 42, Name: "drop-from-metadata"}, + Group: &service.Group{ID: 7, Name: "drop-from-metadata"}, + }, + { + AccountID: 42, + GroupID: 11, + Priority: 3, + Group: &service.Group{ID: 11, Name: "drop-from-metadata"}, + }, + { + AccountID: 42, + GroupID: 0, + Priority: 4, + }, + }, + } + + got := buildSchedulerMetadataAccount(account) + + require.Equal(t, []int64{7, 9, 11}, got.GroupIDs) + require.Len(t, got.AccountGroups, 2) + require.Equal(t, int64(42), got.AccountGroups[0].AccountID) + require.Equal(t, int64(7), got.AccountGroups[0].GroupID) + require.Equal(t, 2, got.AccountGroups[0].Priority) + require.Nil(t, got.AccountGroups[0].Account) + require.Nil(t, got.AccountGroups[0].Group) + require.Equal(t, int64(11), got.AccountGroups[1].GroupID) + require.Nil(t, got.Groups) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f3cae916..d1f12009 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -654,15 +654,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // 1. 
最高优先级:从 metadata.user_id 提取 session_xxx if parsed.MetadataUserID != "" { - if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" { + uid := ParseMetadataUserID(parsed.MetadataUserID) + if uid != nil && uid.SessionID != "" { + slog.Info("sticky.hash_source", + "source", "metadata_user_id", + "session_id", uid.SessionID, + "device_id", uid.DeviceID, + "is_new_format", uid.IsNewFormat, + ) return uid.SessionID } + slog.Info("sticky.hash_metadata_parse_failed", + "metadata_user_id", parsed.MetadataUserID, + "parsed_nil", uid == nil, + ) } // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 cacheableContent := s.extractCacheableContent(parsed) if cacheableContent != "" { - return s.hashContent(cacheableContent) + hash := s.hashContent(cacheableContent) + slog.Info("sticky.hash_source", + "source", "cacheable_content", + "hash", hash, + ) + return hash } // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 @@ -702,7 +718,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } } if combined.Len() > 0 { - return s.hashContent(combined.String()) + hash := s.hashContent(combined.String()) + slog.Info("sticky.hash_source", + "source", "message_content_fallback", + "hash", hash, + "content_len", combined.Len(), + ) + return hash } return "" @@ -1406,14 +1428,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } var stickyAccountID int64 + var stickySource string if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch + stickySource = "prefetch" } else if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID + stickySource = "cache" } } + // [DEBUG-STICKY] 调度器入口日志 + slog.Info("sticky.scheduler_entry", + "group_id", derefGroupID(groupID), + "session_hash", shortSessionHash(sessionHash), + "sticky_account_id", 
stickyAccountID, + "sticky_source", stickySource, + "model", requestedModel, + "load_batch", cfg.LoadBatchEnabled, + "has_concurrency_svc", s.concurrencyService != nil, + "excluded_count", len(excludedIDs), + ) + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1589,6 +1626,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(routingCandidates) > 0 { // 1.5. 在路由账号范围内检查粘性会话 if sessionHash != "" && stickyAccountID > 0 { + slog.Debug("sticky.layer1_5_checking", + "sticky_account_id", stickyAccountID, + "in_routing_list", containsInt64(routingAccountIDs, stickyAccountID), + "is_excluded", isExcluded(stickyAccountID), + "in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(), + "session", shortSessionHash(sessionHash), + ) if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { @@ -1612,6 +1656,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { + slog.Debug("sticky.layer1_5_hit", + "account_id", stickyAccountID, + "session", shortSessionHash(sessionHash), + "result", "slot_acquired", + ) if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } @@ -1762,27 +1811,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 检查账户是否需要清理粘性会话绑定 clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { + slog.Debug("sticky.layer1_5_no_routing_clear", + "account_id", accountID, + "reason", "should_clear_sticky_session", + "session", shortSessionHash(sessionHash), + ) _ = s.cache.DeleteSessionAccountID(ctx, 
derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && - s.isAccountAllowedForPlatform(account, platform, useMixed) && - (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && - s.isAccountSchedulableForQuota(account) && - s.isAccountSchedulableForWindowCost(ctx, account, true) && - s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 + // 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的 + // accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存 + // 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。 + platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed) + modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) + modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) + quotaOK := s.isAccountSchedulableForQuota(account) + windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true) + rpmOK := s.isAccountSchedulableForRPM(ctx, account, true) + schedulable := s.isAccountSchedulableForSelection(account) + + slog.Debug("sticky.layer1_5_no_routing_checks", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "clear_sticky", clearSticky, + "schedulable", schedulable, + "platform_ok", platformOK, + "model_supported", modelSupported, + "model_schedulable", modelSchedulable, + "quota_ok", quotaOK, + "window_cost_ok", windowCostOK, + "rpm_ok", rpmOK, + ) + + if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", 
accountID, + "reason", "session_limit", + "session", shortSessionHash(sessionHash), + ) } else { + slog.Debug("sticky.layer1_5_no_routing_hit", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "result", "slot_acquired", + ) if s.cache != nil { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) } return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } + } else { + slog.Debug("sticky.layer1_5_no_routing_slot_busy", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + ) } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) @@ -1791,6 +1878,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 } else { + slog.Debug("sticky.layer1_5_no_routing_hit", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "result", "wait_plan", + ) return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ AccountID: accountID, MaxConcurrency: account.Concurrency, @@ -1799,12 +1891,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro }) } } + } else if !clearSticky { + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", accountID, + "reason", "gate_check_failed", + "session", shortSessionHash(sessionHash), + ) } + } else { + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", accountID, + "reason", "account_not_in_map", + "session", shortSessionHash(sessionHash), + ) } } + } else if len(routingAccountIDs) == 0 && sessionHash != "" { + slog.Debug("sticky.layer1_5_no_routing_skip", + "sticky_account_id", stickyAccountID, + "is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(), + "session", shortSessionHash(sessionHash), + "reason", func() string { + if stickyAccountID == 0 { + return "no_sticky_binding" + } + return 
"sticky_account_excluded" + }(), + ) } // ============ Layer 2: 负载感知选择 ============ + slog.Debug("sticky.layer2_fallback", + "session", shortSessionHash(sessionHash), + "sticky_account_id", stickyAccountID, + "reason", "sticky_not_used_falling_back_to_load_balance", + "total_accounts", len(accounts), + ) candidates := make([]*Account, 0, len(accounts)) for i := range accounts { acc := &accounts[i] From 094e1171efb4c1886d11d7bbf71088a043b7a4aa Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 30 Apr 2026 12:02:08 +0800 Subject: [PATCH 43/46] fix(openai): infer previous response for item references --- .gitignore | 2 +- backend/internal/service/openai_ws_forwarder.go | 10 ++++++---- .../openai_ws_forwarder_ingress_session_test.go | 4 ++-- .../service/openai_ws_forwarder_ingress_test.go | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index a61f406d..cf251f07 100644 --- a/.gitignore +++ b/.gitignore @@ -122,7 +122,7 @@ scripts .code-review-state #openspec/ code-reviews/ -#AGENTS.md +AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 023217b2..d1386b1b 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1379,10 +1379,12 @@ func shouldInferIngressFunctionCallOutputPreviousResponseID( if signals.HasFunctionCallOutputMissingCallID { return false } - // If the client already sent tool-call context or item_reference anchors, - // treat this as a full replay / self-contained continuation payload rather - // than downgrading it into an inferred delta continuation. 
- if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs { + // If the client already sent the actual tool-call context, treat this as + // a full replay / self-contained continuation payload rather than + // downgrading it into an inferred delta continuation. item_reference alone + // is not enough on the store=false WS path: it still needs a valid prior + // response anchor so upstream can resolve the referenced function_call. + if signals.HasToolCallContext { return false } return strings.TrimSpace(expectedPreviousResponseID) != "" diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 701f069a..30fd4142 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1488,7 +1488,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id") } -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) { +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{} @@ -1619,7 +1619,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.Equal(t, 1, captureDialer.DialCount()) require.Len(t, captureConn.writes, 2) - require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id") + require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 
item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID") } func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index 08597f0c..c735f50a 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -303,12 +303,12 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { want: false, }, { - name: "skip_when_item_reference_already_covers_all_call_ids", + name: "infer_when_only_item_reference_covers_call_ids", storeDisabled: true, turn: 2, signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, expectedPrevious: "resp_2", - want: false, + want: true, }, { name: "skip_when_function_call_output_missing_call_id", From 73b872998e2e44dc8c11e6aec4d55a34fa5badeb Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 30 Apr 2026 13:38:22 +0800 Subject: [PATCH 44/46] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20Anthropic=20?= =?UTF-8?q?=E7=BC=93=E5=AD=98=20TTL=20=E6=B3=A8=E5=85=A5=E5=BC=80=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/handler/admin/setting_handler.go | 18 ++- backend/internal/handler/dto/settings.go | 7 +- backend/internal/service/domain_constants.go | 2 + .../service/gateway_body_order_test.go | 135 ++++++++++++++++++ backend/internal/service/gateway_service.go | 122 ++++++++++++++-- backend/internal/service/setting_service.go | 102 ++++++++----- backend/internal/service/settings_view.go | 7 +- frontend/src/api/admin/settings.ts | 2 + frontend/src/i18n/locales/en.ts | 2 + frontend/src/i18n/locales/zh.ts | 2 + frontend/src/views/admin/SettingsView.vue | 28 ++++ .../admin/__tests__/SettingsView.spec.ts | 21 +++ 12 files changed, 394 insertions(+), 54 
deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index d6580191..59f4fe85 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -209,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, + EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, @@ -441,9 +442,10 @@ type UpdateSettingsRequest struct { BackendModeEnabled bool `json:"backend_mode_enabled"` // Gateway forwarding behavior - EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` - EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` - EnableCCHSigning *bool `json:"enable_cch_signing"` + EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` + EnableCCHSigning *bool `json:"enable_cch_signing"` + EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"` // Payment visible method routing PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` @@ -1273,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + EnableAnthropicCacheTTL1hInjection: func() bool { + if req.EnableAnthropicCacheTTL1hInjection != nil { + return *req.EnableAnthropicCacheTTL1hInjection + } + return previousSettings.EnableAnthropicCacheTTL1hInjection + }(), PaymentVisibleMethodAlipaySource: func() string { if 
req.PaymentVisibleMethodAlipaySource != nil { return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) @@ -1570,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection, PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, @@ -1949,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection { + changed = append(changed, "enable_anthropic_cache_ttl_1h_injection") + } if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { changed = append(changed, "payment_visible_method_alipay_source") } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index b865d703..492be170 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -142,9 +142,10 @@ type SystemSettings struct { BackendModeEnabled bool `json:"backend_mode_enabled"` // Gateway forwarding behavior - EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` - EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` - EnableCCHSigning bool `json:"enable_cch_signing"` + EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` + EnableCCHSigning bool 
`json:"enable_cch_signing"` + EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"` // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index bddcf6ab..bb32540b 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -336,6 +336,8 @@ const ( SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" // SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false) SettingKeyEnableCCHSigning = "enable_cch_signing" + // SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false) + SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection" // Balance Low Notification SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go index e6c9de7d..e0c3cafd 100644 --- a/backend/internal/service/gateway_body_order_test.go +++ b/backend/internal/service/gateway_body_order_test.go @@ -1,13 +1,91 @@ package service import ( + "context" + "errors" "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) +type gatewayTTLSettingRepo struct { + data map[string]string +} + +func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) { + return nil, ErrSettingNotFound +} + +func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) { + if r == nil { + return "", ErrSettingNotFound + } + v, ok := r.data[key] + if !ok { + return "", ErrSettingNotFound + } + return v, nil +} + +func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value 
string) error { + if r == nil { + return errors.New("setting repo is nil") + } + if r.data == nil { + r.data = map[string]string{} + } + r.data[key] = value + return nil +} + +func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + if r == nil { + return result, nil + } + for _, key := range keys { + if v, ok := r.data[key]; ok { + result[key] = v + } + } + return result, nil +} + +func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + if r == nil { + return errors.New("setting repo is nil") + } + if r.data == nil { + r.data = map[string]string{} + } + for key, value := range settings { + r.data[key] = value + } + return nil +} + +func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string) + if r == nil { + return result, nil + } + for key, value := range r.data { + result[key] = value + } + return result, nil +} + +func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error { + if r != nil { + delete(r.data, key) + } + return nil +} + func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) { t.Helper() @@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) { assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`)) } + +func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) { + body := 
[]byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`) + + result := injectAnthropicCacheControlTTL1h(body) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`) + require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String()) + require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) + require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists()) + require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) + require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String()) + require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String()) +} + +func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) { + repo := &gatewayTTLSettingRepo{data: map[string]string{ + SettingKeyEnableAnthropicCacheTTL1hInjection: "true", + }} + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + svc := &GatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth} + + target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account) + require.True(t, ok) + require.Equal(t, cacheTTLTarget5m, target) + + account.Extra = map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "1h", + } + target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account) + require.True(t, ok) + 
require.Equal(t, cacheTTLTarget1h, target) +} + +func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) { + repo := &gatewayTTLSettingRepo{data: map[string]string{ + SettingKeyEnableAnthropicCacheTTL1hInjection: "true", + }} + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + svc := &GatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } + + require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth})) + require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken})) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey})) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth})) + + repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false" + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth})) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d1f12009..074013c3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -62,6 +62,11 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +const ( + cacheTTLTarget5m = "5m" + cacheTTLTarget1h = "1h" +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} @@ -4226,6 +4231,87 @@ func enforceCacheControlLimit(body []byte) []byte { return body } +// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。 +// 仅修改已经存在的 cache_control,不新增缓存断点。 +func 
injectAnthropicCacheControlTTL1h(body []byte) []byte { + return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h) +} + +func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte { + if len(body) == 0 || ttl == "" { + return body + } + out := body + var paths []string + addPath := func(path string, value gjson.Result) { + cc := value.Get("cache_control") + if !cc.Exists() || cc.Get("type").String() != "ephemeral" { + return + } + if cc.Get("ttl").String() == ttl { + return + } + paths = append(paths, path+".cache_control.ttl") + } + + if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl { + paths = append(paths, "cache_control.ttl") + } + + system := gjson.GetBytes(body, "system") + if system.IsArray() { + idx := -1 + system.ForEach(func(_, block gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("system.%d", idx), block) + return true + }) + } + + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIdx := -1 + messages.ForEach(func(_, msg gjson.Result) bool { + msgIdx++ + content := msg.Get("content") + if !content.IsArray() { + return true + } + contentIdx := -1 + content.ForEach(func(_, block gjson.Result) bool { + contentIdx++ + addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block) + return true + }) + return true + }) + } + + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + idx := -1 + tools.ForEach(func(_, tool gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("tools.%d", idx), tool) + return true + }) + } + + for _, path := range paths { + if next, err := sjson.SetBytes(out, path, ttl); err == nil { + out = next + } + } + return out +} + +func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool { + if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil { + return false + } + return 
s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() @@ -4385,6 +4471,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } + if s.shouldInjectAnthropicCacheTTL1h(ctx, account) { + body = injectAnthropicCacheControlTTL1h(body) + } + // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { @@ -7225,9 +7315,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } - // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { - overrideTarget := account.GetCacheTTLOverrideTarget() + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { @@ -7634,6 +7724,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { return true } +func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) { + if account == nil { + return "", false + } + if account.IsCacheTTLOverrideEnabled() { + return account.GetCacheTTLOverrideTarget(), true + } + if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) { + return cacheTTLTarget5m, true + } + return "", false +} + func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, 
originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -7670,9 +7773,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } - // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { - overrideTarget := account.GetCacheTTLOverrideTarget() + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { @@ -8240,10 +8343,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage result.Usage.InputTokens = 0 } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + // Cache TTL Override: 确保计费时 token 分类与账号设置一致。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { - applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { + applyCacheTTLOverride(&result.Usage, overrideTarget) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 966b4b84..2bae686a 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second // cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL) type cachedGatewayForwardingSettings struct { - fingerprintUnification bool - metadataPassthrough bool - cchSigning bool - 
expiresAt int64 // unix nano + fingerprintUnification bool + metadataPassthrough bool + cchSigning bool + anthropicCacheTTL1hInjection bool + expiresAt int64 // unix nano } var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings @@ -1245,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection) updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) @@ -1305,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { }) gatewayForwardingSF.Forget("gateway_forwarding") gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ @@ -1415,22 
+1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { return false } -// GetGatewayForwardingSettings returns cached gateway forwarding settings. -// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. -// Returns (fingerprintUnification, metadataPassthrough, cchSigning). -func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) { +type gatewayForwardingSettingsResult struct { + fp, mp, cch, cacheTTL1h bool +} + +func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + } } } - type gwfResult struct { - fp, mp, cch bool - } val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + }, nil } } dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) @@ -1439,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing SettingKeyEnableFingerprintUnification, SettingKeyEnableMetadataPassthrough, SettingKeyEnableCCHSigning, + 
SettingKeyEnableAnthropicCacheTTL1hInjection, }) if err != nil { slog.Warn("failed to get gateway forwarding settings", "error", err) gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: true, - metadataPassthrough: false, - cchSigning: false, - expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), + fingerprintUnification: true, + metadataPassthrough: false, + cchSigning: false, + anthropicCacheTTL1hInjection: false, + expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), }) - return gwfResult{true, false, false}, nil + return gatewayForwardingSettingsResult{fp: true}, nil } fp := true if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { @@ -1456,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing } mp := values[SettingKeyEnableMetadataPassthrough] == "true" cch := values[SettingKeyEnableCCHSigning] == "true" + cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: fp, - metadataPassthrough: mp, - cchSigning: cch, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + fingerprintUnification: fp, + metadataPassthrough: mp, + cchSigning: cch, + anthropicCacheTTL1hInjection: cacheTTL1h, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) - return gwfResult{fp, mp, cch}, nil + return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil }) - if r, ok := val.(gwfResult); ok { - return r.fp, r.mp, r.cch + if r, ok := val.(gatewayForwardingSettingsResult); ok { + return r } - return true, false, false // fail-open defaults + return gatewayForwardingSettingsResult{fp: true} +} + +// GetGatewayForwardingSettings returns cached gateway forwarding settings. +// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. 
+// Returns (fingerprintUnification, metadataPassthrough, cchSigning). +func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) { + result := s.getGatewayForwardingSettingsCached(ctx) + return result.fp, result.mp, result.cch +} + +// IsAnthropicCacheTTL1hInjectionEnabled 检查是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl。 +func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool { + return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h } // IsEmailVerifyEnabled 检查是否开启邮件验证 @@ -1880,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", - SettingPaymentVisibleMethodAlipaySource: "", - SettingPaymentVisibleMethodWxpaySource: "", - SettingPaymentVisibleMethodAlipayEnabled: "false", - SettingPaymentVisibleMethodWxpayEnabled: "false", - openAIAdvancedSchedulerSettingKey: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyEnableAnthropicCacheTTL1hInjection: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -2228,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" + result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" // Web search emulation: quick enabled check from the JSON config if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" { diff --git 
a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index c0962ff0..41c01cca 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -149,9 +149,10 @@ type SystemSettings struct { BackendModeEnabled bool // Gateway forwarding behavior - EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) - EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) - EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) + EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) + EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false) // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index e8ab6af5..35eef9de 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -439,6 +439,7 @@ export interface SystemSettings { enable_fingerprint_unification: boolean; enable_metadata_passthrough: boolean; enable_cch_signing: boolean; + enable_anthropic_cache_ttl_1h_injection: boolean; web_search_emulation_enabled?: boolean; // Payment configuration @@ -609,6 +610,7 @@ export interface UpdateSettingsRequest { enable_fingerprint_unification?: boolean; enable_metadata_passthrough?: boolean; enable_cch_signing?: boolean; + enable_anthropic_cache_ttl_1h_injection?: boolean; // Payment configuration payment_enabled?: boolean; payment_min_amount?: number; diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 0425955f..2da121fb 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -5019,6 +5019,8 @@ export default { metadataPassthroughHint: 'Pass through client\'s original metadata.user_id without 
rewriting. May improve upstream cache hit rates.', cchSigning: 'CCH Signing', cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.', + anthropicCacheTTL1hInjection: 'Anthropic Cache TTL Injection', + anthropicCacheTTL1hInjectionHint: 'When enabled, existing ephemeral cache_control blocks in Anthropic OAuth/Setup Token request bodies are forced to 1h; response usage is billed back as 5m by default, with account-level TTL billing override taking priority.', }, webSearchEmulation: { title: 'Web Search Emulation', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a8656a7b..7d266522 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -5178,6 +5178,8 @@ export default { metadataPassthroughHint: '透传客户端原始 metadata.user_id,不进行重写。可能提高上游缓存命中率。', cchSigning: 'CCH 签名', cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。', + anthropicCacheTTL1hInjection: 'Anthropic 缓存 TTL 注入', + anthropicCacheTTL1hInjectionHint: '开启后,对 Anthropic OAuth/Setup Token 请求体中已有的 ephemeral 缓存块强制写入 1h;响应 usage 默认按 5m 回写计费,账号级 TTL 计费设置优先。', }, webSearchEmulation: { title: 'Web Search 模拟', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index ad0587b8..13cb0b2c 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3057,6 +3057,31 @@ + + +
+
+ +

+ {{ + t( + "admin.settings.gatewayForwarding.anthropicCacheTTL1hInjectionHint", + ) + }} +

+
+ +
@@ -5810,6 +5835,7 @@ const form = reactive({ enable_fingerprint_unification: true, enable_metadata_passthrough: false, enable_cch_signing: false, + enable_anthropic_cache_ttl_1h_injection: false, // Balance & quota notification balance_low_notify_enabled: false, balance_low_notify_threshold: 0, @@ -6718,6 +6744,8 @@ async function saveSettings() { enable_fingerprint_unification: form.enable_fingerprint_unification, enable_metadata_passthrough: form.enable_metadata_passthrough, enable_cch_signing: form.enable_cch_signing, + enable_anthropic_cache_ttl_1h_injection: + form.enable_anthropic_cache_ttl_1h_injection, // Payment configuration payment_enabled: form.payment_enabled, payment_min_amount: Number(form.payment_min_amount) || 0, diff --git a/frontend/src/views/admin/__tests__/SettingsView.spec.ts b/frontend/src/views/admin/__tests__/SettingsView.spec.ts index 239c474e..4ab475ad 100644 --- a/frontend/src/views/admin/__tests__/SettingsView.spec.ts +++ b/frontend/src/views/admin/__tests__/SettingsView.spec.ts @@ -362,6 +362,7 @@ const baseSettingsResponse = { enable_fingerprint_unification: true, enable_metadata_passthrough: false, enable_cch_signing: false, + enable_anthropic_cache_ttl_1h_injection: false, payment_enabled: true, payment_min_amount: 1, payment_max_amount: 10000, @@ -567,6 +568,26 @@ describe("admin SettingsView payment visible method controls", () => { expect(payload).not.toHaveProperty("payment_visible_method_wxpay_enabled"); }); + it("submits Anthropic cache TTL injection gateway setting", async () => { + getSettings.mockResolvedValueOnce({ + ...baseSettingsResponse, + enable_anthropic_cache_ttl_1h_injection: true, + }); + + const wrapper = mountView(); + + await flushPromises(); + await wrapper.find("form").trigger("submit.prevent"); + await flushPromises(); + + expect(updateSettings).toHaveBeenCalledTimes(1); + expect(updateSettings).toHaveBeenCalledWith( + expect.objectContaining({ + enable_anthropic_cache_ttl_1h_injection: true, + }), + ); + 
}); + it("updates provider enablement immediately and reloads providers", async () => { const provider = { id: 7, From 9d801595c95eb5f5616bca0ec409a42d73325987 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 30 Apr 2026 13:48:27 +0800 Subject: [PATCH 45/46] =?UTF-8?q?test:=20=E6=9B=B4=E6=96=B0=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=91=98=E8=AE=BE=E7=BD=AE=E5=A5=91=E7=BA=A6=E5=AD=97?= =?UTF-8?q?=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/server/api_contract_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f24a1677..607b93dc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -740,6 +740,7 @@ func TestAPIContracts(t *testing.T) { "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, "enable_cch_signing": false, + "enable_anthropic_cache_ttl_1h_injection": false, "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, @@ -934,6 +935,7 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "enable_cch_signing": false, + "enable_anthropic_cache_ttl_1h_injection": false, "web_search_emulation_enabled": false, "payment_visible_method_alipay_source": "", "payment_visible_method_wxpay_source": "", From 48912014a16e2dd1cfca8b7cad785d0e8e7bfeec Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 30 Apr 2026 06:06:12 +0000 Subject: [PATCH 46/46] chore: sync VERSION to 0.1.121 [skip ci] --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 27f3bc3e..025c3166 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION 
@@ -1 +1 @@ -0.1.120 +0.1.121