From c810cad7c860b9a6af7e1dfae106c92bffa3a11c Mon Sep 17 00:00:00 2001 From: Remx Date: Thu, 19 Mar 2026 19:00:22 +0800 Subject: [PATCH 001/125] =?UTF-8?q?feat(openai):=20=E5=A2=9E=E5=8A=A0=20gp?= =?UTF-8?q?t-5.4-mini/nano=20=E6=A8=A1=E5=9E=8B=E6=94=AF=E6=8C=81=E4=B8=8E?= =?UTF-8?q?=E5=AE=9A=E4=BB=B7=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 接入 gpt-5.4-mini/nano 模型识别与规范化,补充默认模型列表 - 增加 gpt-5.4-mini/nano 输入/缓存命中/输出价格与计费兜底逻辑 - 同步前端模型白名单与 OpenCode 配置 - 补充 service tier(priority/flex) 计费回归测试 --- backend/internal/pkg/openai/constants.go | 2 + backend/internal/service/billing_service.go | 16 +++++ .../internal/service/billing_service_test.go | 60 +++++++++++++++++ .../service/openai_codex_transform.go | 8 +++ .../service/openai_codex_transform_test.go | 4 ++ backend/internal/service/pricing_service.go | 28 ++++++++ .../internal/service/pricing_service_test.go | 30 +++++++++ .../model_prices_and_context_window.json | 65 +++++++++++++++++++ frontend/src/components/keys/UseKeyModal.vue | 32 +++++++++ .../keys/__tests__/UseKeyModal.spec.ts | 53 +++++++++++++++ .../__tests__/useModelWhitelist.spec.ts | 11 ++++ frontend/src/composables/useModelWhitelist.ts | 2 +- 12 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 frontend/src/components/keys/__tests__/UseKeyModal.spec.ts diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index b0a31a5f..49e38bf8 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -16,6 +16,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, + {ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"}, + {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 68d7a8f9..99fea0b0 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -221,6 +221,18 @@ func (s *BillingService) initFallbackPricing() { LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, } + s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ + InputPricePerToken: 7.5e-7, + OutputPricePerToken: 4.5e-6, + CacheReadPricePerToken: 7.5e-8, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ + InputPricePerToken: 2e-7, + OutputPricePerToken: 1.25e-6, + CacheReadPricePerToken: 2e-8, + SupportsCacheBreakdown: false, + } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -294,6 +306,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { normalized := normalizeCodexModel(modelLower) switch normalized { + case "gpt-5.4-mini": + return s.fallbackPrices["gpt-5.4-mini"] + case "gpt-5.4-nano": + return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 45bbdcee..10943422 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -174,6 +174,30 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) } +func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4-mini") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 7.5e-7, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 4.5e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 7.5e-8, pricing.CacheReadPricePerToken, 1e-12) + require.Zero(t, pricing.LongContextInputThreshold) +} + +func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricing("gpt-5.4-nano") + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12) + require.Zero(t, pricing.LongContextInputThreshold) +} + func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { svc := newTestBillingService() @@ -210,6 +234,8 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, + {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, + {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, @@ -564,6 +590,40 @@ func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) { require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) } +func TestCalculateCostWithServiceTier_Gpt54MiniPriorityFallsBackToTierMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} + + baseCost, err := svc.CalculateCost("gpt-5.4-mini", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.4-mini", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_Gpt54NanoFlexAppliesHalfMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.4-nano", tokens, 1.0) + require.NoError(t, err) + + flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4-nano", tokens, 1.0, "flex") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) +} + func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) { svc := newTestBillingService() tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 0ae55ad3..d0534d8c 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -7,6 +7,8 @@ import ( var codexModelMap = map[string]string{ "gpt-5.4": "gpt-5.4", + "gpt-5.4-mini": "gpt-5.4-mini", + "gpt-5.4-nano": "gpt-5.4-nano", "gpt-5.4-none": "gpt-5.4", "gpt-5.4-low": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4", @@ -225,6 +227,12 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { + return "gpt-5.4-mini" + } + if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") { + return "gpt-5.4-nano" + } if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return "gpt-5.4" } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index b52f0566..eab88c09 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -238,6 +238,10 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.4-high": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4", "gpt 5.4": "gpt-5.4", + "gpt-5.4-mini": "gpt-5.4-mini", + "gpt 5.4 mini": "gpt-5.4-mini", + "gpt-5.4-nano": "gpt-5.4-nano", + "gpt 5.4 nano": "gpt-5.4-nano", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 7ed4e7e4..10440c60 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -34,6 +34,22 @@ var ( Mode: "chat", SupportsPromptCaching: true, } + openAIGPT54MiniFallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 7.5e-07, + OutputCostPerToken: 4.5e-06, + CacheReadInputTokenCost: 7.5e-08, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } + openAIGPT54NanoFallbackPricing = &LiteLLMModelPricing{ + InputCostPerToken: 2e-07, + OutputCostPerToken: 1.25e-06, + CacheReadInputTokenCost: 2e-08, + LiteLLMProvider: "openai", + Mode: "chat", + SupportsPromptCaching: true, + } ) // LiteLLMModelPricing LiteLLM价格数据结构 @@ -723,6 +739,18 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + if strings.HasPrefix(model, "gpt-5.4-mini") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)")) + return openAIGPT54MiniFallbackPricing + } + + if strings.HasPrefix(model, "gpt-5.4-nano") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-nano(static)")) + return openAIGPT54NanoFallbackPricing + } + if strings.HasPrefix(model, "gpt-5.4") { logger.With(zap.String("component", "service.pricing")). Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 775024fd..13a5c70c 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -98,6 +98,36 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) } +func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4-mini") + require.NotNil(t, got) + require.InDelta(t, 7.5e-7, got.InputCostPerToken, 1e-12) + require.InDelta(t, 4.5e-6, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 7.5e-8, got.CacheReadInputTokenCost, 1e-12) + require.Zero(t, got.LongContextInputTokenThreshold) +} + +func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("gpt-5.4-nano") + require.NotNil(t, got) + require.InDelta(t, 2e-7, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.25e-6, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2e-8, got.CacheReadInputTokenCost, 1e-12) + require.Zero(t, got.LongContextInputTokenThreshold) +} + func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { raw := map[string]any{ "gpt-5.4": map[string]any{ diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 72860bf9..0a096257 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5173,6 +5173,71 @@ "supports_tool_choice": true, "supports_vision": true }, + "gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "openai", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, "cache_read_input_token_cost_priority": 3.5e-07, diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 634db115..7770e658 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -709,6 +709,38 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, + 'gpt-5.4-mini': { + name: 'GPT-5.4 Mini', + limit: { + context: 400000, + output: 128000 + }, + options: { + store: false + }, + variants: { + low: {}, + medium: {}, + high: {}, + xhigh: {} + } + }, + 'gpt-5.4-nano': { + name: 'GPT-5.4 Nano', + limit: { + context: 400000, + output: 128000 + }, + options: { + store: false + }, + variants: { + low: {}, + medium: {}, + high: {}, + xhigh: {} + } + }, 'gpt-5.3-codex-spark': { name: 'GPT-5.3 Codex Spark', limit: { diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts new file mode 100644 index 00000000..98b5dede --- /dev/null +++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts @@ -0,0 +1,53 @@ +import { describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import { nextTick } from 'vue' + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) +})) + +vi.mock('@/composables/useClipboard', () => ({ + useClipboard: () => ({ + copyToClipboard: vi.fn().mockResolvedValue(true) + }) +})) + +import UseKeyModal from '../UseKeyModal.vue' + +describe('UseKeyModal', () => { + it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => { + const wrapper = mount(UseKeyModal, { + props: { + show: true, + apiKey: 'sk-test', + baseUrl: 'https://example.com/v1', + platform: 'openai' + }, + global: { + stubs: { + BaseDialog: { + template: '
' + }, + Icon: { + template: '' + } + } + } + }) + + const opencodeTab = wrapper.findAll('button').find((button) => + button.text().includes('keys.useKeyModal.cliTabs.opencode') + ) + + expect(opencodeTab).toBeDefined() + await opencodeTab!.trigger('click') + await nextTick() + + const codeBlock = wrapper.find('pre code') + expect(codeBlock.exists()).toBe(true) + expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"') + expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"') + }) +}) diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index b4308a63..4061be4d 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -11,6 +11,8 @@ describe('useModelWhitelist', () => { const models = getModelsByPlatform('openai') expect(models).toContain('gpt-5.4') + expect(models).toContain('gpt-5.4-mini') + expect(models).toContain('gpt-5.4-nano') expect(models).toContain('gpt-5.4-2026-03-05') }) @@ -52,4 +54,13 @@ describe('useModelWhitelist', () => { 'gpt-5.4-2026-03-05': 'gpt-5.4-2026-03-05' }) }) + + it('whitelist keeps GPT-5.4 mini and nano exact mappings', () => { + const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini', 'gpt-5.4-nano'], []) + + expect(mapping).toEqual({ + 'gpt-5.4-mini': 'gpt-5.4-mini', + 'gpt-5.4-nano': 'gpt-5.4-nano' + }) + }) }) diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 0ff288bb..9e7cb036 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -25,7 +25,7 @@ const openaiModels = [ 'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest', 'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', // GPT-5.4 系列 - 'gpt-5.4', 'gpt-5.4-2026-03-05', + 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.4-2026-03-05', // GPT-5.3 系列 'gpt-5.3-codex', 'gpt-5.3-codex-spark', 'chatgpt-4o-latest', From 995ef1348a63d6ee238cbb5fd9aa4ff4cb1d6fcc Mon Sep 17 00:00:00 2001 From: InCerry Date: Tue, 24 Mar 2026 19:20:15 +0800 Subject: [PATCH 002/125] refactor: improve model resolution and normalization logic for OpenAI integration --- .../service/openai_codex_transform.go | 2 +- .../service/openai_codex_transform_test.go | 29 ++++++++++++++ .../service/openai_compat_prompt_cache_key.go | 8 ++-- .../openai_compat_prompt_cache_key_test.go | 15 +++++++ .../openai_gateway_chat_completions.go | 33 +++++++++------ .../service/openai_gateway_messages.go | 29 +++++++++----- .../service/openai_gateway_service.go | 40 +++++++++---------- .../internal/service/openai_model_mapping.go | 28 +++++++++++-- .../service/openai_model_mapping_test.go | 29 +++++++++++--- .../internal/service/openai_ws_forwarder.go | 14 ++----- 10 files changed, 159 insertions(+), 68 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index d0534d8c..21b4874e 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if v, ok := reqBody["model"].(string); ok { model = v } - normalizedModel := normalizeCodexModel(model) + normalizedModel := strings.TrimSpace(model) if normalizedModel != "" { if model != normalizedModel { reqBody["model"] = normalizedModel diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index eab88c09..889ac615 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt 5.3 codex spark": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex", @@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) +} + +func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) { + reqBody := map[string]any{ + "model": " gpt-5.3-codex-spark ", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { // Codex CLI 场景:已有 instructions 时不修改 diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 88e16a4d..46381838 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,8 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch normalizeCodexModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex": + switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return true default: return false @@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod return "" } - normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) + normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel)) if normalizedModel == "" { - normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) + normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model)) } if normalizedModel == "" { normalizedModel = strings.TrimSpace(req.Model) diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go index eb9148de..6ca3e85c 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key_test.go +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) } @@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") require.NotEqual(t, k1, k2, "different first user messages should yield different keys") } + +func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) { + req := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.3-codex-spark", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark") + k2 := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ") + require.NotEmpty(t, k1) + require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index a442da33..1d5bf0d0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) promptCacheKey = strings.TrimSpace(promptCacheKey) compatPromptCacheInjected := false - if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) { - promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel) + if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel) compatPromptCacheInjected = promptCacheKey != "" } @@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( if err != nil { return nil, fmt.Errorf("convert chat completions to responses: %w", err) } - responsesReq.Model = mappedModel + responsesReq.Model = upstreamModel logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", clientStream), } if compatPromptCacheInjected { @@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) } else { - result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, includeUsage bool, startTime time.Time, ) (*OpenAIForwardResult, error) { @@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 6a29823a..e9548b79 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -59,13 +59,15 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) - responsesReq.Model = mappedModel + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) + responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), ) @@ -81,6 +83,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -181,10 +186,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } else { // Client wants JSON: buffer the streaming response and assemble a JSON reply. - result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -229,7 +234,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -302,8 +308,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -318,7 +324,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -351,8 +358,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4e96cf05..daccf38f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1778,29 +1778,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // 对所有请求执行模型映射(包含 Codex CLI)。 - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) - reqBody["model"] = mappedModel + billingModel := account.GetMappedModel(reqModel) + if billingModel != reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) + reqBody["model"] = billingModel bodyModified = true - markPatchSet("model", mappedModel) + markPatchSet("model", billingModel) } + upstreamModel := billingModel // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { - normalizedModel := normalizeCodexModel(model) - if normalizedModel != "" && normalizedModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, normalizedModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = normalizedModel - mappedModel = normalizedModel + upstreamModel = resolveOpenAIUpstreamModel(model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel bodyModified = true - markPatchSet("model", normalizedModel) + markPatchSet("model", upstreamModel) } // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 // 确保高版本模型向低版本模型映射不报错 - if !SupportsVerbosity(normalizedModel) { + if !SupportsVerbosity(upstreamModel) { if text, ok := reqBody["text"].(map[string]any); ok { delete(text, "verbosity") } @@ -1824,7 +1824,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } if codexResult.NormalizedModel != "" { - mappedModel = codexResult.NormalizedModel + upstreamModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey @@ -1941,7 +1941,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", account.ID, account.Type, - mappedModel, + upstreamModel, reqStream, hasPreviousResponseID, ) @@ -2030,7 +2030,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco isCodexCLI, reqStream, originalModel, - mappedModel, + upstreamModel, startTime, attempt, wsLastFailureReason, @@ -2131,7 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) - wsResult.UpstreamModel = mappedModel + wsResult.UpstreamModel = upstreamModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2236,14 +2236,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var usage *OpenAIUsage var firstTokenMs *int if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } @@ -2267,7 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, - UpstreamModel: mappedModel, + UpstreamModel: upstreamModel, ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 9bf3fba3..4f8c094b 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,8 +1,10 @@ package service -// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible -// forwarding. Group-level default mapping only applies when the account itself -// did not match any explicit model_mapping rule. +import "strings" + +// resolveOpenAIForwardModel resolves the account/group mapping result for +// OpenAI-compatible forwarding. Group-level default mapping only applies when +// the account itself did not match any explicit model_mapping rule. func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { if defaultMappedModel != "" { @@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } return mappedModel } + +func resolveOpenAIUpstreamModel(model string) string { + if isBareGPT53CodexSparkModel(model) { + return "gpt-5.3-codex-spark" + } + return normalizeCodexModel(strings.TrimSpace(model)) +} + +func isBareGPT53CodexSparkModel(model string) bool { + modelID := strings.TrimSpace(model) + if modelID == "" { + return false + } + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + normalized := strings.ToLower(strings.TrimSpace(modelID)) + return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark" +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index edbb968b..42f58b37 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -74,13 +74,30 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * Credentials: map[string]any{}, } - withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") - if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1") + withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) + if withoutDefault != "gpt-5.1" { + t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1") } - withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") - if got := normalizeCodexModel(withDefault); got != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4") + withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + if withDefault != "gpt-5.4" { + t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") + } +} + +func TestResolveOpenAIUpstreamModel(t *testing.T) { + cases := map[string]string{ + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + " openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + } + + for input, expected := range cases { + if got := resolveOpenAIUpstreamModel(input); got != expected { + t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) + } } } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 814ec0bd..9c30a390 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - mappedModel := account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } - if mappedModel != originalModel { - next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + if upstreamModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } @@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } + mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) From ad2cd97618791bd93ffa04b9497da3613b9ae796 Mon Sep 17 00:00:00 2001 From: haruka <1628615876@qq.com> Date: Mon, 30 Mar 2026 16:23:38 +0800 Subject: [PATCH 003/125] fix: resolve refresh token race condition causing false invalid_grant errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When multiple goroutines/workers concurrently refresh the same OAuth token, the first succeeds but invalidates the old refresh_token (rotation). Subsequent attempts using the stale token get invalid_grant, which was incorrectly treated as non-retryable, permanently marking the account as ERROR. Three complementary fixes: 1. Race-aware recovery: after invalid_grant, re-read DB to check if another worker already refreshed (refresh_token changed) — return success instead of error 2. In-process mutex (sync.Map of per-account locks): prevents concurrent refreshes within the same process, complementing the Redis distributed lock 3. Increase default lock TTL from 30s to 60s to reduce TTL-expiry races Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/internal/service/oauth_refresh_api.go | 73 +++++- .../service/oauth_refresh_api_test.go | 219 ++++++++++++++++++ 2 files changed, 286 insertions(+), 6 deletions(-) diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 5dbba638..fdc0ce40 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "strconv" + "strings" + "sync" "time" ) @@ -17,7 +19,7 @@ type OAuthRefreshExecutor interface { CacheKey(account *Account) string } -const refreshLockTTL = 30 * time.Second +const defaultRefreshLockTTL = 60 * time.Second // OAuthRefreshResult 统一刷新结果 type OAuthRefreshResult struct { @@ -28,20 +30,34 @@ type OAuthRefreshResult struct { } // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 -// 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑 type OAuthRefreshAPI struct { accountRepo AccountRepository - tokenCache GeminiTokenCache // 可选,nil = 无锁 + tokenCache GeminiTokenCache // 可选,nil = 无分布式锁 + lockTTL time.Duration + localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex } // NewOAuthRefreshAPI 创建统一刷新 API -func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { +// 可选传入 lockTTL 覆盖默认的 60s 分布式锁 TTL +func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI { + ttl := defaultRefreshLockTTL + if len(lockTTL) > 0 && lockTTL[0] > 0 { + ttl = lockTTL[0] + } return &OAuthRefreshAPI{ accountRepo: accountRepo, tokenCache: tokenCache, + lockTTL: ttl, } } +// getLocalLock 返回指定 cacheKey 的进程内互斥锁 +func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex { + val, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{}) + return val.(*sync.Mutex) +} + // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token // // 流程: @@ -59,12 +75,17 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) + // 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争) + localMu := api.getLocalLock(cacheKey) + localMu.Lock() + defer localMu.Unlock() + // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { - acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) + acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL) if lockErr != nil { - // Redis 错误,降级为无锁刷新 + // Redis 错误,降级为无锁刷新(进程内互斥锁仍生效) slog.Warn("oauth_refresh_lock_failed_degraded", "account_id", account.ID, "cache_key", cacheKey, @@ -102,6 +123,19 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 4. 执行平台特定刷新逻辑 newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) if refreshErr != nil { + // 竞争恢复:invalid_grant 可能是另一个 worker 已消费了旧 refresh_token + // 重新读取 DB,如果 refresh_token 已更新则说明是竞争,返回成功 + if isInvalidGrantError(refreshErr) { + if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered { + slog.Info("oauth_refresh_race_recovered", + "account_id", freshAccount.ID, + "platform", freshAccount.Platform, + ) + return &OAuthRefreshResult{ + Account: recoveredAccount, + }, nil + } + } return nil, refreshErr } @@ -126,6 +160,33 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( }, nil } +// isInvalidGrantError 检查错误是否为 invalid_grant +func isInvalidGrantError(err error) bool { + return err != nil && strings.Contains(strings.ToLower(err.Error()), "invalid_grant") +} + +// tryRecoverFromRefreshRace 在 invalid_grant 错误后尝试竞争恢复 +// 重新读取 DB,如果 refresh_token 已改变(说明另一个 worker 成功刷新),则返回更新后的 account +func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) { + if api.accountRepo == nil { + return nil, false + } + reReadAccount, err := api.accountRepo.GetByID(ctx, usedAccount.ID) + if err != nil || reReadAccount == nil { + return nil, false + } + usedRT := usedAccount.GetCredential("refresh_token") + currentRT := reReadAccount.GetCredential("refresh_token") + if usedRT == "" || currentRT == "" { + return nil, false + } + // refresh_token 不同 → 另一个 worker 已成功刷新 + if usedRT != currentRT { + return reReadAccount, true + } + return nil, false +} + // MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { if newCreds == nil { diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index c3b38ddf..4a60723b 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -5,6 +5,7 @@ package service import ( "context" "errors" + "sync" "testing" "time" @@ -385,6 +386,224 @@ func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) { require.False(t, hasScope, "scope should not be set when empty") } +// refreshAPIAccountRepoWithRace supports returning a different account on subsequent GetByID calls +// to simulate race conditions where another worker has refreshed the token. +type refreshAPIAccountRepoWithRace struct { + refreshAPIAccountRepo + raceAccount *Account // returned on 2nd+ GetByID call + getByIDCalls int +} + +func (r *refreshAPIAccountRepoWithRace) GetByID(_ context.Context, _ int64) (*Account, error) { + r.getByIDCalls++ + if r.getByIDCalls > 1 && r.raceAccount != nil { + return r.raceAccount, nil + } + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.account, nil +} + +// ========== Race recovery tests ========== + +func TestRefreshIfNeeded_InvalidGrantRaceRecovered(t *testing.T) { + // Account with old refresh token + account := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt", "access_token": "old-at"}, + } + // After race, DB has new refresh token from another worker + racedAccount := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: racedAccount, + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token not found or invalid"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err, "race-recovered invalid_grant should not return error") + require.False(t, result.Refreshed) + require.False(t, result.LockHeld) + require.NotNil(t, result.Account) + require.Equal(t, "new-rt", result.Account.GetCredential("refresh_token")) + require.Equal(t, 0, repo.updateCalls) // no DB update needed, another worker did it +} + +func TestRefreshIfNeeded_InvalidGrantGenuine(t *testing.T) { + // Account with revoked refresh token - DB still has the same token + account := &Account{ + ID: 11, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "revoked-rt", "access_token": "old-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: account, // same refresh_token on re-read + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token revoked"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "genuine invalid_grant should propagate error") + require.Nil(t, result) + require.Contains(t, err.Error(), "invalid_grant") +} + +func TestRefreshIfNeeded_InvalidGrantDBRereadFailsOnRecovery(t *testing.T) { + account := &Account{ + ID: 12, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: nil, // GetByID returns nil on recovery attempt + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "should propagate error when recovery DB re-read fails") + require.Nil(t, result) +} + +func TestRefreshIfNeeded_LocalMutexSerializesConcurrent(t *testing.T) { + // Test that two goroutines for the same account are serialized by the local mutex. + // The first goroutine refreshes successfully; the second sees NeedsRefresh=false. + refreshed := &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + callCount := 0 + repo := &refreshAPIAccountRepo{account: &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + }} + + // After first refresh, NeedsRefresh should return false + // We simulate this by using an executor that decrements needsRefresh after first call + var mu sync.Mutex + dynamicExecutor := &dynamicRefreshExecutor{ + canRefresh: true, + cacheKey: "test:mutex:anthropic", + refreshFunc: func(_ context.Context, _ *Account) (map[string]any, error) { + mu.Lock() + callCount++ + mu.Unlock() + time.Sleep(50 * time.Millisecond) // slow refresh + return map[string]any{"access_token": "new-at"}, nil + }, + needsRefreshFunc: func() bool { + mu.Lock() + defer mu.Unlock() + return callCount == 0 // only first call needs refresh + }, + } + + _ = refreshed + + api := NewOAuthRefreshAPI(repo, nil) // no distributed lock, only local mutex + + var wg sync.WaitGroup + results := make([]*OAuthRefreshResult, 2) + errs := make([]error, 2) + + for i := 0; i < 2; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx], errs[idx] = api.RefreshIfNeeded(context.Background(), repo.account, dynamicExecutor, 3*time.Minute) + }(i) + } + wg.Wait() + + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + // Only one goroutine should have actually called Refresh + mu.Lock() + require.Equal(t, 1, callCount, "only one refresh call should have been made") + mu.Unlock() +} + +// dynamicRefreshExecutor is a test helper with function-based NeedsRefresh and Refresh. +type dynamicRefreshExecutor struct { + canRefresh bool + cacheKey string + needsRefreshFunc func() bool + refreshFunc func(context.Context, *Account) (map[string]any, error) +} + +func (e *dynamicRefreshExecutor) CanRefresh(_ *Account) bool { return e.canRefresh } + +func (e *dynamicRefreshExecutor) NeedsRefresh(_ *Account, _ time.Duration) bool { + return e.needsRefreshFunc() +} + +func (e *dynamicRefreshExecutor) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + return e.refreshFunc(ctx, account) +} + +func (e *dynamicRefreshExecutor) CacheKey(_ *Account) string { + return e.cacheKey +} + +// ========== NewOAuthRefreshAPI TTL tests ========== + +func TestNewOAuthRefreshAPI_DefaultTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_CustomTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 90*time.Second) + require.Equal(t, 90*time.Second, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_ZeroTTLUsesDefault(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 0) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +// ========== isInvalidGrantError tests ========== + +func TestIsInvalidGrantError(t *testing.T) { + require.True(t, isInvalidGrantError(errors.New("invalid_grant: token revoked"))) + require.True(t, isInvalidGrantError(errors.New("INVALID_GRANT"))) + require.False(t, isInvalidGrantError(errors.New("invalid_client"))) + require.False(t, isInvalidGrantError(nil)) +} + // ========== BackgroundRefreshPolicy tests ========== func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) { From 49e99e9d519a55cbe6d3bd94b810978dd64ce4b8 Mon Sep 17 00:00:00 2001 From: haruka <1628615876@qq.com> Date: Mon, 30 Mar 2026 16:44:15 +0800 Subject: [PATCH 004/125] fix: resolve errcheck lint for sync.Map type assertion Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/internal/service/oauth_refresh_api.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index fdc0ce40..571e9ecd 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -54,8 +54,13 @@ func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCac // getLocalLock 返回指定 cacheKey 的进程内互斥锁 func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex { - val, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{}) - return val.(*sync.Mutex) + actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{}) + mu, ok := actual.(*sync.Mutex) + if !ok { + mu = &sync.Mutex{} + api.localLocks.Store(cacheKey, mu) + } + return mu } // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token From 6b646b61273e8cc0b6ec93b8d179e8539e2651e5 Mon Sep 17 00:00:00 2001 From: qingyuzhang Date: Mon, 30 Mar 2026 22:29:26 +0800 Subject: [PATCH 005/125] fix(openai): fail over passthrough 429 and 529 --- .../service/openai_gateway_service.go | 58 +++++- .../service/openai_oauth_passthrough_test.go | 190 ++++++++++++++---- 2 files changed, 206 insertions(+), 42 deletions(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 0a959615..d0ae5a2f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2430,7 +2430,11 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + // 透传模式默认保持原样代理;但 429/529 属于网关必须兜底的 + // 上游容量类错误,应先触发多账号 failover 以维持基础 SLA。 + if shouldFailoverOpenAIPassthroughResponse(resp.StatusCode) { + return nil, s.handleFailoverErrorResponsePassthrough(ctx, resp, c, account, body) + } return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) } @@ -2613,6 +2617,58 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( return req, nil } +func shouldFailoverOpenAIPassthroughResponse(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, 529: + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + if s.rateLimitService != nil { + _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + return &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + ResponseHeaders: resp.Header.Clone(), + } +} + func (s *OpenAIGatewayService) handleErrorResponsePassthrough( ctx context.Context, resp *http.Response, diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 97fa218d..69c9de42 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -48,6 +48,22 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc return u.Do(req, proxyURL, accountID, accountConcurrency) } +type openAIPassthroughFailoverRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + overloadCalls []time.Time +} + +func (r *openAIPassthroughFailoverRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIPassthroughFailoverRepo) SetOverloaded(_ context.Context, _ int64, until time.Time) error { + r.overloadCalls = append(r.overloadCalls, until) + return nil +} + var structuredLogCaptureMu sync.Mutex type inMemoryLogSink struct { @@ -527,6 +543,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF _, err := svc.Forward(context.Background(), c, account, originalBody) require.Error(t, err) + require.True(t, c.Writer.Written(), "非 429/529 的 passthrough 错误应继续原样写回客户端") + require.Equal(t, http.StatusBadRequest, rec.Code) // should append an upstream error event with passthrough=true v, ok := c.Get(OpsUpstreamErrorsKey) @@ -535,55 +553,145 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, ok) require.NotEmpty(t, arr) require.True(t, arr[len(arr)-1].Passthrough) + require.Equal(t, "http_error", arr[len(arr)-1].Kind) } -func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) { +func TestOpenAIGatewayService_OpenAIPassthrough_429And529TriggerFailover(t *testing.T) { gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) - resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() - resp := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{ - "Content-Type": []string{"application/json"}, - "x-request-id": []string{"rid-rate-limit"}, + + newAccount := func(accountType string) *Account { + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: accountType, + Concurrency: 1, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + switch accountType { + case AccountTypeOAuth: + account.Credentials = map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"} + case AccountTypeAPIKey: + account.Credentials = map[string]any{"api_key": "sk-test"} + } + return account + } + + testCases := []struct { + name string + accountType string + statusCode int + body string + assertRepo func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) + }{ + { + name: "oauth_429_rate_limit", + accountType: AccountTypeOAuth, + statusCode: http.StatusTooManyRequests, + body: func() string { + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt) + }(), + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) { + require.Len(t, repo.rateLimitCalls, 1) + require.Empty(t, repo.overloadCalls) + require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour) + }, + }, + { + name: "oauth_529_overload", + accountType: AccountTypeOAuth, + statusCode: 529, + body: `{"error":{"message":"server overloaded","type":"server_error"}}`, + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) { + require.Empty(t, repo.rateLimitCalls) + require.Len(t, repo.overloadCalls, 1) + require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second) + }, + }, + { + name: "apikey_429_rate_limit", + accountType: AccountTypeAPIKey, + statusCode: http.StatusTooManyRequests, + body: func() string { + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt) + }(), + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) { + require.Len(t, repo.rateLimitCalls, 1) + require.Empty(t, repo.overloadCalls) + require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour) + }, + }, + { + name: "apikey_529_overload", + accountType: AccountTypeAPIKey, + statusCode: 529, + body: `{"error":{"message":"server overloaded","type":"server_error"}}`, + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) { + require.Empty(t, repo.rateLimitCalls) + require.Len(t, repo.overloadCalls, 1) + require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second) + }, }, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))), - } - upstream := &httpUpstreamRecorder{resp: resp} - repo := &openAIWSRateLimitSignalRepo{} - rateSvc := &RateLimitService{accountRepo: repo} - - svc := &OpenAIGatewayService{ - cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, - httpUpstream: upstream, - rateLimitService: rateSvc, } - account := &Account{ - ID: 123, - Name: "acc", - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_passthrough": true}, - Status: StatusActive, - Schedulable: true, - RateMultiplier: f64p(1), - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - _, err := svc.Forward(context.Background(), c, account, originalBody) - require.Error(t, err) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - require.Contains(t, rec.Body.String(), "usage_limit_reached") - require.Len(t, repo.rateLimitCalls, 1) - require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + resp := &http.Response{ + StatusCode: tc.statusCode, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-failover"}, + }, + Body: io.NopCloser(strings.NewReader(tc.body)), + } + upstream := &httpUpstreamRecorder{resp: resp} + repo := &openAIPassthroughFailoverRepo{} + rateSvc := &RateLimitService{ + accountRepo: repo, + cfg: &config.Config{ + RateLimit: config.RateLimitConfig{OverloadCooldownMinutes: 10}, + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + rateLimitService: rateSvc, + } + + account := newAccount(tc.accountType) + start := time.Now() + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, tc.statusCode, failoverErr.StatusCode) + require.False(t, c.Writer.Written(), "429/529 passthrough 应返回 failover 错误给上层换号,而不是直接向客户端写响应") + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + arr, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, arr) + require.True(t, arr[len(arr)-1].Passthrough) + require.Equal(t, "failover", arr[len(arr)-1].Kind) + require.Equal(t, tc.statusCode, arr[len(arr)-1].UpstreamStatusCode) + + tc.assertRepo(t, repo, start) + }) + } } func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { From a61d58716fb7a15a2bbedc5e7d8260f5b7bf0974 Mon Sep 17 00:00:00 2001 From: weak-fox <827367480@qq.com> Date: Tue, 31 Mar 2026 00:00:46 +0800 Subject: [PATCH 006/125] fix(admin): exclude rate-limited accounts from active filter --- backend/internal/repository/account_repo.go | 8 ++++++++ .../repository/account_repo_integration_test.go | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a12..28aca32b 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -468,6 +468,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati } if status != "" { switch status { + case service.StatusActive: + q = q.Where( + dbaccount.StatusEQ(status), + dbaccount.Or( + dbaccount.RateLimitResetAtIsNil(), + dbaccount.RateLimitResetAtLTE(time.Now()), + ), + ) case "rate_limited": q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) case "temp_unschedulable": diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 8da30c92..f3e3f745 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -255,6 +255,22 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Equal(service.StatusDisabled, accounts[0].Status) }, }, + { + name: "filter_by_status_active_excludes_rate_limited", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive}) + rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) + err := client.Account.UpdateOneID(rateLimited.ID). + SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + }, + status: service.StatusActive, + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("active-normal", accounts[0].Name) + }, + }, { name: "filter_by_search", setup: func(client *dbent.Client) { From 0b3feb9d4c85df45d83b3165cc35de8cda67d6b3 Mon Sep 17 00:00:00 2001 From: InCerry Date: Tue, 31 Mar 2026 10:33:28 +0800 Subject: [PATCH 007/125] fix(openai): resolve Anthropic compat mapping from normalized model Anthropic compat requests normalize reasoning suffixes before forwarding, but the account mapping step was still using the raw request model. Resolve billing and upstream models from the normalized compat model so explicit account mappings win over fallback defaults. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- backend/internal/service/openai_gateway_messages.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 02efc23b..8c389556 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } originalModel := anthropicReq.Model applyOpenAICompatModelNormalization(&anthropicReq) + normalizedModel := anthropicReq.Model clientStream := anthropicReq.Stream // client's original stream preference // 2. Convert Anthropic → Responses @@ -60,13 +61,14 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) upstreamModel := resolveOpenAIUpstreamModel(billingModel) responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), + zap.String("normalized_model", normalizedModel), zap.String("billing_model", billingModel), zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), From 46bc5ca73b8e8e166321e789f47e24ede50677d2 Mon Sep 17 00:00:00 2001 From: QTom Date: Sat, 28 Mar 2026 21:56:45 +0800 Subject: [PATCH 008/125] =?UTF-8?q?feat(antigravity):=20=E4=BB=A4=E7=89=8C?= =?UTF-8?q?=E5=88=B7=E6=96=B0=E5=A4=B1=E8=B4=A5=E5=8F=8A=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E6=97=B6=E4=B9=9F=E8=AE=BE=E7=BD=AE=E9=9A=90?= =?UTF-8?q?=E7=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - token_refresh: 不可重试错误和重试耗尽两条路径添加 ensureAntigravityPrivacy - admin_service: CreateAccount 为 Antigravity OAuth 账号异步设置隐私 Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/internal/service/admin_service.go | 13 +++++++++++++ backend/internal/service/token_refresh_service.go | 2 ++ 2 files changed, 15 insertions(+) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 88c064f3..50b8d26c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -1587,6 +1588,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Antigravity OAuth 账号:创建后异步设置隐私 + if account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth { + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r) + } + }() + s.EnsureAntigravityPrivacy(context.Background(), account) + }() + } + return account, nil } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index d39095ea..fb2b5210 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -305,6 +305,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } // 刷新失败但 access_token 可能仍有效,尝试设置隐私 s.ensureOpenAIPrivacy(ctx, account) + s.ensureAntigravityPrivacy(ctx, account) return err } @@ -334,6 +335,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc // 刷新失败但 access_token 可能仍有效,尝试设置隐私 s.ensureOpenAIPrivacy(ctx, account) + s.ensureAntigravityPrivacy(ctx, account) // 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试) until := time.Now().Add(tokenRefreshTempUnschedDuration) From aeed2eb9add1727a483080cbe36b7e0e63d0e2e8 Mon Sep 17 00:00:00 2001 From: QTom Date: Fri, 27 Mar 2026 18:02:48 +0800 Subject: [PATCH 009/125] =?UTF-8?q?feat(group-filter):=20=E5=88=86?= =?UTF-8?q?=E7=BB=84=E8=B4=A6=E5=8F=B7=E8=BF=87=E6=BB=A4=E6=8E=A7=E5=88=B6?= =?UTF-8?q?=20=E2=80=94=20require=5Foauth=5Fonly=20+=20require=5Fprivacy?= =?UTF-8?q?=5Fset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 OpenAI/Antigravity/Anthropic/Gemini 分组新增两个布尔控制字段: - require_oauth_only: 创建/更新账号绑定分组时拒绝 apikey 类型加入 - require_privacy_set: 调度选号时跳过 privacy 未成功设置的账号并标记 error 后端:Ent schema 新增字段 + 迁移、Group CRUD 全链路透传、 gateway_service 与 openai_account_scheduler 两套调度路径过滤 前端:创建/编辑表单 toggle 开关(OpenAI/Antigravity/Anthropic/Gemini 平台可见) Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/ent/group.go | 24 +++- backend/ent/group/group.go | 20 +++ backend/ent/group/where.go | 30 ++++ backend/ent/group_create.go | 130 ++++++++++++++++++ backend/ent/group_update.go | 68 +++++++++ backend/ent/migrate/schema.go | 2 + backend/ent/mutation.go | 110 ++++++++++++++- backend/ent/runtime/runtime.go | 10 +- backend/ent/schema/group.go | 6 + .../internal/handler/admin/group_handler.go | 8 ++ backend/internal/handler/dto/mappers.go | 2 + backend/internal/handler/dto/types.go | 4 + backend/internal/repository/api_key_repo.go | 2 + backend/internal/repository/group_repo.go | 4 + backend/internal/server/api_contract_test.go | 2 + backend/internal/service/account.go | 15 ++ backend/internal/service/account_service.go | 26 ++++ backend/internal/service/admin_service.go | 54 ++++++++ .../service/gateway_multiplatform_test.go | 6 +- backend/internal/service/gateway_service.go | 36 +++++ backend/internal/service/group.go | 2 + .../service/openai_account_scheduler.go | 13 ++ .../service/scheduler_snapshot_service.go | 8 ++ .../081_add_group_account_filter.sql | 2 + frontend/src/types/index.ts | 6 + frontend/src/views/admin/GroupsView.vue | 124 +++++++++++++++++ 26 files changed, 708 insertions(+), 6 deletions(-) create mode 100644 backend/migrations/081_add_group_account_filter.sql diff --git a/backend/ent/group.go b/backend/ent/group.go index 3db54a64..fc691a9b 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -80,6 +80,10 @@ type Group struct { SortOrder int `json:"sort_order,omitempty"` // 是否允许 /v1/messages 调度到此 OpenAI 分组 AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` + // 仅允许非 apikey 类型账号关联到此分组 + RequireOauthOnly bool `json:"require_oauth_only,omitempty"` + // 调度时仅允许 privacy 已成功设置的账号 + RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -190,7 +194,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) @@ -425,6 +429,18 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.AllowMessagesDispatch = value.Bool } + case group.FieldRequireOauthOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_oauth_only", values[i]) + } else if value.Valid { + _m.RequireOauthOnly = value.Bool + } + case group.FieldRequirePrivacySet: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_privacy_set", values[i]) + } else if value.Valid { + _m.RequirePrivacySet = value.Bool + } case group.FieldDefaultMappedModel: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i]) @@ -628,6 +644,12 @@ func (_m *Group) String() string { builder.WriteString("allow_messages_dispatch=") builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch)) builder.WriteString(", ") + builder.WriteString("require_oauth_only=") + builder.WriteString(fmt.Sprintf("%v", _m.RequireOauthOnly)) + builder.WriteString(", ") + builder.WriteString("require_privacy_set=") + builder.WriteString(fmt.Sprintf("%v", _m.RequirePrivacySet)) + builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) builder.WriteByte(')') diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 2612b6cf..35222127 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -77,6 +77,10 @@ const ( FieldSortOrder = "sort_order" // FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database. FieldAllowMessagesDispatch = "allow_messages_dispatch" + // FieldRequireOauthOnly holds the string denoting the require_oauth_only field in the database. + FieldRequireOauthOnly = "require_oauth_only" + // FieldRequirePrivacySet holds the string denoting the require_privacy_set field in the database. + FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. @@ -185,6 +189,8 @@ var Columns = []string{ FieldSupportedModelScopes, FieldSortOrder, FieldAllowMessagesDispatch, + FieldRequireOauthOnly, + FieldRequirePrivacySet, FieldDefaultMappedModel, } @@ -255,6 +261,10 @@ var ( DefaultSortOrder int // DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field. DefaultAllowMessagesDispatch bool + // DefaultRequireOauthOnly holds the default value on creation for the "require_oauth_only" field. + DefaultRequireOauthOnly bool + // DefaultRequirePrivacySet holds the default value on creation for the "require_privacy_set" field. + DefaultRequirePrivacySet bool // DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field. DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -414,6 +424,16 @@ func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc() } +// ByRequireOauthOnly orders the results by the require_oauth_only field. +func ByRequireOauthOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequireOauthOnly, opts...).ToFunc() +} + +// ByRequirePrivacySet orders the results by the require_privacy_set field. +func ByRequirePrivacySet(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequirePrivacySet, opts...).ToFunc() +} + // ByDefaultMappedModel orders the results by the default_mapped_model field. func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 5dd8759e..41bd575a 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -200,6 +200,16 @@ func AllowMessagesDispatch(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnly applies equality check predicate on the "require_oauth_only" field. It's identical to RequireOauthOnlyEQ. +func RequireOauthOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySet applies equality check predicate on the "require_privacy_set" field. It's identical to RequirePrivacySetEQ. +func RequirePrivacySet(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ. func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) @@ -1490,6 +1500,26 @@ func AllowMessagesDispatchNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnlyEQ applies the EQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequireOauthOnlyNEQ applies the NEQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySetEQ applies the EQ predicate on the "require_privacy_set" field. +func RequirePrivacySetEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + +// RequirePrivacySetNEQ applies the NEQ predicate on the "require_privacy_set" field. +func RequirePrivacySetNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field. func DefaultMappedModelEQ(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 6db5b974..a635dfd9 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -438,6 +438,34 @@ func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate { return _c } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_c *GroupCreate) SetRequireOauthOnly(v bool) *GroupCreate { + _c.mutation.SetRequireOauthOnly(v) + return _c +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequireOauthOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetRequireOauthOnly(*v) + } + return _c +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_c *GroupCreate) SetRequirePrivacySet(v bool) *GroupCreate { + _c.mutation.SetRequirePrivacySet(v) + return _c +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequirePrivacySet(v *bool) *GroupCreate { + if v != nil { + _c.SetRequirePrivacySet(*v) + } + return _c +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate { _c.mutation.SetDefaultMappedModel(v) @@ -645,6 +673,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultAllowMessagesDispatch _c.mutation.SetAllowMessagesDispatch(v) } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + v := group.DefaultRequireOauthOnly + _c.mutation.SetRequireOauthOnly(v) + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + v := group.DefaultRequirePrivacySet + _c.mutation.SetRequirePrivacySet(v) + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) @@ -722,6 +758,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)} } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + return &ValidationError{Name: "require_oauth_only", err: errors.New(`ent: missing required field "Group.require_oauth_only"`)} + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + return &ValidationError{Name: "require_privacy_set", err: errors.New(`ent: missing required field "Group.require_privacy_set"`)} + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)} } @@ -881,6 +923,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) _node.AllowMessagesDispatch = value } + if value, ok := _c.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + _node.RequireOauthOnly = value + } + if value, ok := _c.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + _node.RequirePrivacySet = value + } if value, ok := _c.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value @@ -1587,6 +1637,30 @@ func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert { return u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsert) SetRequireOauthOnly(v bool) *GroupUpsert { + u.Set(group.FieldRequireOauthOnly, v) + return u +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequireOauthOnly() *GroupUpsert { + u.SetExcluded(group.FieldRequireOauthOnly) + return u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsert) SetRequirePrivacySet(v bool) *GroupUpsert { + u.Set(group.FieldRequirePrivacySet, v) + return u +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequirePrivacySet() *GroupUpsert { + u.SetExcluded(group.FieldRequirePrivacySet) + return u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert { u.Set(group.FieldDefaultMappedModel, v) @@ -2281,6 +2355,34 @@ func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertOne) SetRequireOauthOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequireOauthOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsertOne) SetRequirePrivacySet(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequirePrivacySet() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -3143,6 +3245,34 @@ func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertBulk) SetRequireOauthOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequireOauthOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsertBulk) SetRequirePrivacySet(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequirePrivacySet() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index b3698596..a9a4b9da 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -639,6 +639,34 @@ func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate { return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdate) SetRequireOauthOnly(v bool) *GroupUpdate { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequireOauthOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdate) SetRequirePrivacySet(v bool) *GroupUpdate { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequirePrivacySet(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate { _u.mutation.SetDefaultMappedModel(v) @@ -1146,6 +1174,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } @@ -2067,6 +2101,34 @@ func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdateOne) SetRequireOauthOnly(v bool) *GroupUpdateOne { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequireOauthOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdateOne) SetRequirePrivacySet(v bool) *GroupUpdateOne { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequirePrivacySet(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne { _u.mutation.SetDefaultMappedModel(v) @@ -2604,6 +2666,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index c472d7e0..6c56f2d0 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -409,6 +409,8 @@ var ( {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false}, + {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, + {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, } // GroupsTable holds the schema information for the "groups" table. diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 42c63c2e..a862209d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8253,6 +8253,8 @@ type GroupMutation struct { sort_order *int addsort_order *int allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool default_mapped_model *string clearedFields map[string]struct{} api_keys map[int64]struct{} @@ -10034,6 +10036,78 @@ func (m *GroupMutation) ResetAllowMessagesDispatch() { m.allow_messages_dispatch = nil } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (m *GroupMutation) SetRequireOauthOnly(b bool) { + m.require_oauth_only = &b +} + +// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. +func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { + v := m.require_oauth_only + if v == nil { + return + } + return *v, true +} + +// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) + } + return oldValue.RequireOauthOnly, nil +} + +// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. +func (m *GroupMutation) ResetRequireOauthOnly() { + m.require_oauth_only = nil +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (m *GroupMutation) SetRequirePrivacySet(b bool) { + m.require_privacy_set = &b +} + +// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. +func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { + v := m.require_privacy_set + if v == nil { + return + } + return *v, true +} + +// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + } + return oldValue.RequirePrivacySet, nil +} + +// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. +func (m *GroupMutation) ResetRequirePrivacySet() { + m.require_privacy_set = nil +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (m *GroupMutation) SetDefaultMappedModel(s string) { m.default_mapped_model = &s @@ -10428,7 +10502,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 34) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10522,6 +10596,12 @@ func (m *GroupMutation) Fields() []string { if m.allow_messages_dispatch != nil { fields = append(fields, group.FieldAllowMessagesDispatch) } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } @@ -10595,6 +10675,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SortOrder() case group.FieldAllowMessagesDispatch: return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() } @@ -10668,6 +10752,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSortOrder(ctx) case group.FieldAllowMessagesDispatch: return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) } @@ -10896,6 +10984,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetAllowMessagesDispatch(v) return nil + case group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil case group.FieldDefaultMappedModel: v, ok := value.(string) if !ok { @@ -11333,6 +11435,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldAllowMessagesDispatch: m.ResetAllowMessagesDispatch() return nil + case group.FieldRequireOauthOnly: + m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ca95f13f..fd6be291 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -458,8 +458,16 @@ func init() { groupDescAllowMessagesDispatch := groupFields[27].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) + // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. + groupDescRequireOauthOnly := groupFields[28].Descriptor() + // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. + group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) + // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. + groupDescRequirePrivacySet := groupFields[29].Descriptor() + // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. + group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[28].Descriptor() + groupDescDefaultMappedModel := groupFields[30].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 0f5a7b14..fd83bf26 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -153,6 +153,12 @@ func (Group) Fields() []ent.Field { field.Bool("allow_messages_dispatch"). Default(false). Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"), + field.Bool("require_oauth_only"). + Default(false). + Comment("仅允许非 apikey 类型账号关联到此分组"), + field.Bool("require_privacy_set"). + Default(false). + Comment("调度时仅允许 privacy 已成功设置的账号"), field.String("default_mapped_model"). MaxLen(100). Default(""). diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 459fd949..caa27bc3 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -112,6 +112,8 @@ type CreateGroupRequest struct { SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -150,6 +152,8 @@ type UpdateGroupRequest struct { SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -267,6 +271,8 @@ func (h *GroupHandler) Create(c *gin.Context) { SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -320,6 +326,8 @@ func (h *GroupHandler) Update(c *gin.Context) { SupportedModelScopes: req.SupportedModelScopes, SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 0b5448af..a8da92c0 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -181,6 +181,8 @@ func groupFromServiceBase(g *service.Group) Group { FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOAuthOnly, + RequirePrivacySet: g.RequirePrivacySet, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8af6990e..46984044 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -102,6 +102,10 @@ type Group struct { // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + // 账号过滤控制(仅 OpenAI/Antigravity 平台有效) + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 667193a6..ade0d464 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -662,6 +662,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { SupportedModelScopes: g.SupportedModelScopes, SortOrder: g.SortOrder, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOauthOnly, + RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 674c655b..3cfd649b 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -61,6 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetMcpXMLInject(groupIn.MCPXMLInject). SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel) // 设置模型路由配置 @@ -130,6 +132,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetMcpXMLInject(groupIn.MCPXMLInject). SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index ac4e05de..450c3122 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -214,6 +214,8 @@ func TestAPIContracts(t *testing.T) { "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "allow_messages_dispatch": false, + "require_oauth_only": false, + "require_privacy_set": false, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a1449ffd..53eefda3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -141,6 +141,21 @@ func (a *Account) IsOAuth() bool { return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken } +// IsPrivacySet 检查账号的 privacy 是否已成功设置。 +// OpenAI: privacy_mode == "training_off" +// Antigravity: privacy_mode == "privacy_set" +// 其他平台: 无 privacy 概念,始终返回 true +func (a *Account) IsPrivacySet() bool { + switch a.Platform { + case PlatformOpenAI: + return a.getExtraString("privacy_mode") == PrivacyModeTrainingOff + case PlatformAntigravity: + return a.getExtraString("privacy_mode") == AntigravityPrivacySet + default: + return true + } +} + func (a *Account) IsGemini() bool { return a.Platform == PlatformGemini } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 71d51712..328790a8 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -174,6 +174,19 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( return nil, fmt.Errorf("create account: %w", err) } + // require_oauth_only 检查:apikey 类型账号不可加入限制分组 + if account.Type == AccountTypeAPIKey && len(req.GroupIDs) > 0 { + for _, gid := range req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if len(req.GroupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil { @@ -277,6 +290,19 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount return nil, fmt.Errorf("update account: %w", err) } + // require_oauth_only 检查 + if account.Type == AccountTypeAPIKey && req.GroupIDs != nil { + for _, gid := range *req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if req.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 88c064f3..52c9837c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -162,6 +162,8 @@ type CreateGroupInput struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -201,6 +203,8 @@ type UpdateGroupInput struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -941,12 +945,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn SupportedModelScopes: input.SupportedModelScopes, SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, AllowMessagesDispatch: input.AllowMessagesDispatch, + RequireOAuthOnly: input.RequireOAuthOnly, + RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 如果有需要复制的账号,绑定到新分组 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { @@ -1154,6 +1181,12 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.AllowMessagesDispatch != nil { group.AllowMessagesDispatch = *input.AllowMessagesDispatch } + if input.RequireOAuthOnly != nil { + group.RequireOAuthOnly = *input.RequireOAuthOnly + } + if input.RequirePrivacySet != nil { + group.RequirePrivacySet = *input.RequirePrivacySet + } if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } @@ -1201,6 +1234,27 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 再绑定源分组的账号 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index f28912bb..2d16ad94 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -3139,7 +3139,7 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 0, groupRepo.getByIDLiteCalls) } @@ -3182,7 +3182,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } @@ -3252,7 +3252,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b54f463b..7b7b61ac 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2744,6 +2744,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2815,6 +2821,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2917,6 +2929,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2980,6 +2998,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -3047,6 +3071,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3151,6 +3181,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e17032e0..e0f81a39 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -59,6 +59,8 @@ type Group struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) DefaultMappedModel string CreatedAt time.Time diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 37e7ed2c..6c09e354 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -4,6 +4,7 @@ import ( "container/heap" "context" "errors" + "fmt" "hash/fnv" "math" "sort" @@ -575,6 +576,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( return nil, 0, 0, 0, errors.New("no available OpenAI accounts") } + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if req.GroupID != nil && s.service.schedulerSnapshot != nil { + schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID) + } + filtered := make([]*Account, 0, len(accounts)) loadReq := make([]AccountWithConcurrency, 0, len(accounts)) for i := range accounts { @@ -587,6 +594,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if !account.IsSchedulable() || !account.IsOpenAI() { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() { + _ = s.service.accountRepo.SetError(ctx, account.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { continue } diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4c9540f1..d1330abb 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -152,6 +152,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int return s.accountRepo.GetByID(fallbackCtx, accountID) } +// GetGroupByID 获取分组信息(供调度器使用) +func (s *SchedulerSnapshotService) GetGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if s.groupRepo == nil { + return nil, nil + } + return s.groupRepo.GetByID(ctx, groupID) +} + // UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效) func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error { if s.cache == nil || account == nil { diff --git a/backend/migrations/081_add_group_account_filter.sql b/backend/migrations/081_add_group_account_filter.sql new file mode 100644 index 00000000..0afb21d9 --- /dev/null +++ b/backend/migrations/081_add_group_account_filter.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups ADD COLUMN IF NOT EXISTS require_oauth_only BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE groups ADD COLUMN IF NOT EXISTS require_privacy_set BOOLEAN NOT NULL DEFAULT false; diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index f9425ad0..54ef1a10 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -399,6 +399,8 @@ export interface Group { fallback_group_id_on_invalid_request: number | null // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) allow_messages_dispatch?: boolean + require_oauth_only: boolean + require_privacy_set: boolean created_at: string updated_at: string } @@ -510,6 +512,8 @@ export interface CreateGroupRequest { mcp_xml_inject?: boolean simulate_claude_max_enabled?: boolean supported_model_scopes?: string[] + require_oauth_only?: boolean + require_privacy_set?: boolean // 从指定分组复制账号 copy_accounts_from_group_ids?: number[] } @@ -539,6 +543,8 @@ export interface UpdateGroupRequest { mcp_xml_inject?: boolean simulate_claude_max_enabled?: boolean supported_model_scopes?: string[] + require_oauth_only?: boolean + require_privacy_set?: boolean copy_accounts_from_group_ids?: number[] } diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index a7c1a10d..c7aaf683 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -792,6 +792,61 @@ + +
+

账号过滤控制

+ + +
+
+ +

+ {{ createForm.require_oauth_only ? '已启用 — 排除 API Key 类型账号' : '未启用' }} +

+
+ +
+ + +
+
+ +

+ {{ createForm.require_privacy_set ? '已启用 — Privacy 未设置的账号将被排除' : '未启用' }} +

+
+ +
+
+
+ +
+

账号过滤控制

+ + +
+
+ +

+ {{ editForm.require_oauth_only ? '已启用 — 排除 API Key 类型账号' : '未启用' }} +

+
+ +
+ + +
+
+ +

+ {{ editForm.require_privacy_set ? '已启用 — Privacy 未设置的账号将被排除' : '未启用' }} +

+
+ +
+
+
{ createForm.fallback_group_id = null createForm.fallback_group_id_on_invalid_request = null createForm.allow_messages_dispatch = false + createForm.require_oauth_only = false + createForm.require_privacy_set = false createForm.default_mapped_model = 'gpt-5.4' createForm.supported_model_scopes = ['claude', 'gemini_text', 'gemini_image'] createForm.mcp_xml_inject = true @@ -2539,6 +2657,8 @@ const handleEdit = async (group: AdminGroup) => { editForm.fallback_group_id = group.fallback_group_id editForm.fallback_group_id_on_invalid_request = group.fallback_group_id_on_invalid_request editForm.allow_messages_dispatch = group.allow_messages_dispatch || false + editForm.require_oauth_only = group.require_oauth_only ?? false + editForm.require_privacy_set = group.require_privacy_set ?? false editForm.default_mapped_model = group.default_mapped_model || '' editForm.model_routing_enabled = group.model_routing_enabled || false editForm.supported_model_scopes = group.supported_model_scopes || ['claude', 'gemini_text', 'gemini_image'] @@ -2647,6 +2767,10 @@ watch( createForm.allow_messages_dispatch = false createForm.default_mapped_model = '' } + if (!['openai', 'antigravity', 'anthropic', 'gemini'].includes(newVal)) { + createForm.require_oauth_only = false + createForm.require_privacy_set = false + } } ) From 72e5876c64ee231d5b973546b1a85a262ffd0947 Mon Sep 17 00:00:00 2001 From: QTom Date: Tue, 31 Mar 2026 13:19:40 +0800 Subject: [PATCH 010/125] feat(gateway): Cache-Driven RPM Buffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - buffer 公式从 baseRPM/5 改为 concurrency + maxSessions 保留 baseRPM/5 作为 floor 向后兼容 - 粘性路径 fallback 新增 [StickyCacheMiss] 结构化日志 reason: rpm_red / gate_check / session_limit / wait_queue_full / account_cleared - session_limit 路径跳过 wait queue 重试(RegisterSession 拒绝无副作用) - 典型配置 buffer 从 3 提升至 13,大幅减少高峰期 Prompt Cache Miss Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/internal/service/account.go | 33 +++++++++-- backend/internal/service/account_rpm_test.go | 55 ++++++++++++------ backend/internal/service/gateway_service.go | 61 ++++++++++++++------ 3 files changed, 110 insertions(+), 39 deletions(-) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a1449ffd..5ced20a7 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1727,22 +1727,47 @@ func (a *Account) GetRPMStrategy() string { } // GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 -// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +// Cache-driven: buffer = concurrency + maxSessions(覆盖幽灵窗口 + 稳态会话需求) +// floor = baseRPM / 5(向后兼容 maxSessions=0 且 concurrency=0 场景) func (a *Account) GetRPMStickyBuffer() int { if a.Extra == nil { return 0 } + + // 手动 override 最高优先级 if v, ok := a.Extra["rpm_sticky_buffer"]; ok { val := parseExtraInt(v) if val > 0 { return val } } + base := a.GetBaseRPM() - buffer := base / 5 - if buffer < 1 && base > 0 { - buffer = 1 + if base <= 0 { + return 0 } + + // Cache-driven buffer = concurrency + maxSessions + conc := a.Concurrency + if conc < 0 { + conc = 0 + } + sess := a.GetMaxSessions() + if sess < 0 { + sess = 0 + } + + buffer := conc + sess + + // floor: 向后兼容 + floor := base / 5 + if floor < 1 { + floor = 1 + } + if buffer < floor { + buffer = floor + } + return buffer } diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go index 9d91f3e0..40298263 100644 --- a/backend/internal/service/account_rpm_test.go +++ b/backend/internal/service/account_rpm_test.go @@ -90,28 +90,47 @@ func TestCheckRPMSchedulability(t *testing.T) { func TestGetRPMStickyBuffer(t *testing.T) { tests := []struct { - name string - extra map[string]any - expected int + name string + concurrency int + extra map[string]any + expected int }{ - {"nil extra", nil, 0}, - {"no keys", map[string]any{}, 0}, - {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, - {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, - {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, - {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, - {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, - {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, - {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, - {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, - {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, - {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, - {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, - {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + // 基础退化 + {"nil extra", 0, nil, 0}, + {"no keys", 0, map[string]any{}, 0}, + {"base_rpm=0", 0, map[string]any{"base_rpm": 0}, 0}, + + // 新公式: concurrency + maxSessions, floor = base/5 + {"conc=3 sess=10 → 13", 3, map[string]any{"base_rpm": 15, "max_sessions": 10}, 13}, + {"conc=2 sess=5 → 7", 2, map[string]any{"base_rpm": 10, "max_sessions": 5}, 7}, + {"conc=3 sess=15 → 18", 3, map[string]any{"base_rpm": 30, "max_sessions": 15}, 18}, + + // floor 生效 (conc+sess < base/5) + {"conc=0 sess=0 base=15 → floor 3", 0, map[string]any{"base_rpm": 15}, 3}, + {"conc=0 sess=0 base=10 → floor 2", 0, map[string]any{"base_rpm": 10}, 2}, + {"conc=0 sess=0 base=1 → floor 1", 0, map[string]any{"base_rpm": 1}, 1}, + {"conc=0 sess=0 base=4 → floor 1", 0, map[string]any{"base_rpm": 4}, 1}, + {"conc=1 sess=0 base=15 → floor 3", 1, map[string]any{"base_rpm": 15}, 3}, + + // 手动 override + {"custom buffer=5", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5, "max_sessions": 10}, 5}, + {"custom buffer=0 fallback", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0, "max_sessions": 10}, 13}, + {"custom buffer negative fallback", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1, "max_sessions": 10}, 13}, + {"custom buffer with float", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + + // 负值 clamp + {"negative concurrency clamped", -5, map[string]any{"base_rpm": 15, "max_sessions": 10}, 10}, + {"negative maxSessions clamped", 3, map[string]any{"base_rpm": 15, "max_sessions": -5}, 3}, + + // 高并发低会话 + {"conc=10 sess=5 → 15", 10, map[string]any{"base_rpm": 10, "max_sessions": 5}, 15}, + + // json.Number + {"json.Number base_rpm", 3, map[string]any{"base_rpm": json.Number("10"), "max_sessions": json.Number("5")}, 8}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &Account{Extra: tt.extra} + a := &Account{Concurrency: tt.concurrency, Extra: tt.extra} if got := a.GetRPMStickyBuffer(); got != tt.expected { t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b54f463b..ec14ac62 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1418,19 +1418,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if s.isAccountSchedulableForSelection(stickyAccount) && + var stickyCacheMissReason string + + gatePass := s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) + + if rpmPass { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 + stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { @@ -1444,27 +1449,49 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - // 会话限制已满,继续到负载感知选择 + if stickyCacheMissReason == "" { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + stickyCacheMissReason = "session_limit" + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + stickyCacheMissReason = "wait_queue_full" } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } else if !gatePass { + stickyCacheMissReason = "gate_check" + } else { + stickyCacheMissReason = "rpm_red" + } + + // 记录粘性缓存未命中的结构化日志 + if stickyCacheMissReason != "" { + baseRPM := stickyAccount.GetBaseRPM() + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { + currentRPM = count + } + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", + stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", + stickyAccountID, shortSessionHash(sessionHash)) } } } From a025a15f5db3d70138fd56cc71a0defe916978e6 Mon Sep 17 00:00:00 2001 From: shaw Date: Tue, 31 Mar 2026 13:53:49 +0800 Subject: [PATCH 011/125] feat: add refresh button to admin and user dashboard pages --- .../src/components/user/dashboard/UserDashboardCharts.vue | 5 ++++- frontend/src/views/admin/DashboardView.vue | 3 +++ frontend/src/views/user/DashboardView.vue | 5 +++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/frontend/src/components/user/dashboard/UserDashboardCharts.vue b/frontend/src/components/user/dashboard/UserDashboardCharts.vue index 22148592..73e88c3b 100644 --- a/frontend/src/components/user/dashboard/UserDashboardCharts.vue +++ b/frontend/src/components/user/dashboard/UserDashboardCharts.vue @@ -7,6 +7,9 @@ {{ t('dashboard.timeRange') }}:
+
{{ t('dashboard.granularity') }}:
@@ -74,7 +77,7 @@ import { Chart as ChartJS, CategoryScale, LinearScale, PointElement, LineElement ChartJS.register(CategoryScale, LinearScale, PointElement, LineElement, ArcElement, Title, Tooltip, Legend, Filler) const props = defineProps<{ loading: boolean, startDate: string, endDate: string, granularity: string, trend: TrendDataPoint[], models: ModelStat[] }>() -defineEmits(['update:startDate', 'update:endDate', 'update:granularity', 'dateRangeChange', 'granularityChange']) +defineEmits(['update:startDate', 'update:endDate', 'update:granularity', 'dateRangeChange', 'granularityChange', 'refresh']) const { t } = useI18n() const modelData = computed(() => !props.models?.length ? null : { diff --git a/frontend/src/views/admin/DashboardView.vue b/frontend/src/views/admin/DashboardView.vue index 20dd90d2..430b7cee 100644 --- a/frontend/src/views/admin/DashboardView.vue +++ b/frontend/src/views/admin/DashboardView.vue @@ -219,6 +219,9 @@ @change="onDateRangeChange" />
+
{{ t('admin.dashboard.granularity') }}:
@@ -62,6 +66,7 @@ interface Props { type: AccountType planType?: string privacyMode?: string + subscriptionExpiresAt?: string } const props = defineProps() @@ -148,6 +153,22 @@ const planBadgeClass = computed(() => { return typeClass.value }) +// Subscription expiration label (non-free only) +const expiresLabel = computed(() => { + if (!props.subscriptionExpiresAt || !props.planType) return '' + if (props.planType.toLowerCase() === 'free') return '' + try { + const d = new Date(props.subscriptionExpiresAt) + if (isNaN(d.getTime())) return '' + const yyyy = d.getFullYear() + const mm = String(d.getMonth() + 1).padStart(2, '0') + const dd = String(d.getDate()).padStart(2, '0') + return `${t('admin.accounts.subscriptionExpires')} ${yyyy}-${mm}-${dd}` + } catch { + return '' + } +}) + // Privacy badge — shows different states for OpenAI/Antigravity OAuth privacy setting const privacyBadge = computed(() => { if (props.type !== 'oauth' || !props.privacyMode) return null diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index f5267d6a..ba3703ef 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1988,6 +1988,7 @@ export default { privacyAntigravityFailed: 'Privacy setting failed', setPrivacy: 'Set Privacy', subscriptionAbnormal: 'Abnormal', + subscriptionExpires: 'Expires', // Capacity status tooltips capacity: { windowCost: { diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 9581206e..52dd6cdb 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2026,6 +2026,7 @@ export default { privacyAntigravityFailed: '隐私设置失败', setPrivacy: '设置隐私', subscriptionAbnormal: '异常', + subscriptionExpires: '到期', // 容量状态提示 capacity: { windowCost: { diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 35e0fcec..0cc8341c 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -182,7 +182,7 @@