From b7edc3ed82006c28daa9c33b95cd6562c3f7f169 Mon Sep 17 00:00:00 2001 From: shuanbao0 Date: Sat, 11 Apr 2026 20:22:18 +0800 Subject: [PATCH 001/122] =?UTF-8?q?fix(gateway):=20=E5=85=BC=E5=AE=B9=20Cu?= =?UTF-8?q?rsor=20/v1/chat/completions=20=E7=9A=84=20Responses=20API=20bod?= =?UTF-8?q?y?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cursor 云端 (User-Agent: Go-http-client/2.0) 发往 /v1/chat/completions 的 body 使用 Responses API 格式: {"model":"gpt-5.4","input":[{"role":"system","content":"..."}],"stream":true} 原代码用 ChatCompletionsRequest 反序列化,该结构体没有 Input 字段, Cursor 的 input 数组被静默丢弃,ChatCompletionsToResponses 转换后产出 input: null,Codex 上游以 "Invalid type for 'input': expected a string, but got an object" 拒绝请求(上游 typeof null === 'object')。 修复:在 ForwardAsChatCompletions 里用 gjson 检测 body shape,当 input 存在且 messages 缺失时,跳过 Chat→Responses 转换,用 sjson 仅改写 model 字段后原样透传 body。billing 所需的 ServiceTier 和 Reasoning.Effort 通过 gjson 从 raw body 提取,下游 codex OAuth transform 路径保持不变。 测试:新增 openai_cursor_warmup_pipeline_test.go,覆盖 5 个 shape 检测 用例(正向/标准请求不误伤/两字段共存/空 body/JSON 回读)。 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../openai_cursor_warmup_pipeline_test.go | 155 ++++++++++++++++++ .../openai_gateway_chat_completions.go | 60 +++++-- 2 files changed, 203 insertions(+), 12 deletions(-) create mode 100644 backend/internal/service/openai_cursor_warmup_pipeline_test.go diff --git a/backend/internal/service/openai_cursor_warmup_pipeline_test.go b/backend/internal/service/openai_cursor_warmup_pipeline_test.go new file mode 100644 index 00000000..8ade9dbb --- /dev/null +++ b/backend/internal/service/openai_cursor_warmup_pipeline_test.go @@ -0,0 +1,155 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// TestCursorMixedShapeDetection covers the core invariant of the Cursor +// compatibility fix in ForwardAsChatCompletions: when a client POSTs a +// Responses-shaped body (has `input`, no `messages`) to /v1/chat/completions, +// the request must be forwarded as-is with only the `model` field rewritten. +// The raw `input` array (including Cursor's 80KB system prompt) must not be +// discarded or reshaped. +// +// Context: +// +// Before the fix, the handler unmarshaled the body into ChatCompletionsRequest, +// which has no Input field, silently dropping Cursor's input. The subsequent +// conversion produced `input: null`, which Codex upstreams reject with +// "Invalid type for 'input': expected a string, but got an object". +func TestCursorMixedShapeDetection(t *testing.T) { + // Representative Cursor cloud body — shape is what matters, content is + // abridged. Notice: `input` is a Responses-API array, there is no + // `messages` field at all, and `user`/`stream` are at the top level. + cursorBody := []byte(`{ + "user": "85df22e7463ab6c2", + "model": "gpt-5.4", + "stream": true, + "input": [ + {"role":"system","content":"You are GPT-5.4 running as a coding agent."}, + {"role":"user","content":"hello"} + ], + "service_tier": "auto", + "reasoning": {"effort": "high"} + }`) + + // --- Step 1: Shape detection (mirrors ForwardAsChatCompletions) --- + hasMessages := gjson.GetBytes(cursorBody, "messages").Exists() + hasInput := gjson.GetBytes(cursorBody, "input").Exists() + isResponsesShape := !hasMessages && hasInput + + require.True(t, isResponsesShape, + "Cursor body must be detected as Responses-shape (has input, no messages)") + + // --- Step 2: Model rewrite (mirrors the sjson.SetBytes branch) --- + const upstreamModel = "gpt-5.1-codex" + rewritten, err := sjson.SetBytes(cursorBody, "model", upstreamModel) + require.NoError(t, err) + + // --- Step 3: Invariants of the rewritten body --- + + // 3a. model must be rewritten to the upstream target. + assert.Equal(t, upstreamModel, gjson.GetBytes(rewritten, "model").String()) + + // 3b. input array must be preserved verbatim — no reshaping, no nulling. + inputResult := gjson.GetBytes(rewritten, "input") + require.True(t, inputResult.Exists(), "input field must still exist after rewrite") + require.True(t, inputResult.IsArray(), "input must still be an array (not null, not object)") + + items := inputResult.Array() + require.Len(t, items, 2, "both input items must be preserved") + assert.Equal(t, "system", items[0].Get("role").String()) + assert.Equal(t, "You are GPT-5.4 running as a coding agent.", + items[0].Get("content").String()) + assert.Equal(t, "user", items[1].Get("role").String()) + assert.Equal(t, "hello", items[1].Get("content").String()) + + // 3c. ALL other top-level fields must survive intact. + assert.Equal(t, "85df22e7463ab6c2", gjson.GetBytes(rewritten, "user").String()) + assert.Equal(t, true, gjson.GetBytes(rewritten, "stream").Bool()) + assert.Equal(t, "auto", gjson.GetBytes(rewritten, "service_tier").String()) + assert.Equal(t, "high", gjson.GetBytes(rewritten, "reasoning.effort").String()) + + // 3d. Final upstream body must NOT contain the old "input":null pattern. + assert.NotContains(t, string(rewritten), `"input":null`, + "rewritten body must not collapse input to null") +} + +// TestCursorMixedShapeDetection_NormalChatCompletionsUnaffected guards that +// the shape detection does NOT misfire on a standard Chat Completions request +// (one that has a `messages` array). Such requests must fall through to the +// existing ChatCompletionsToResponses conversion path. +func TestCursorMixedShapeDetection_NormalChatCompletionsUnaffected(t *testing.T) { + body := []byte(`{ + "model": "gpt-4o", + "messages": [{"role":"user","content":"hi"}], + "stream": true + }`) + + hasMessages := gjson.GetBytes(body, "messages").Exists() + hasInput := gjson.GetBytes(body, "input").Exists() + isResponsesShape := !hasMessages && hasInput + + assert.False(t, isResponsesShape, + "standard Chat Completions body must NOT be detected as Responses-shape") +} + +// TestCursorMixedShapeDetection_BothFieldsPrefersMessages guards the +// ambiguous case where a client sends both `messages` and `input`. We fall +// through to the normal conversion path (messages wins), since mixing the +// two is almost certainly a client bug and messages is the documented +// Chat Completions contract. +func TestCursorMixedShapeDetection_BothFieldsPrefersMessages(t *testing.T) { + body := []byte(`{ + "model": "gpt-4o", + "messages": [{"role":"user","content":"hi"}], + "input": [{"role":"user","content":"other"}] + }`) + + hasMessages := gjson.GetBytes(body, "messages").Exists() + hasInput := gjson.GetBytes(body, "input").Exists() + isResponsesShape := !hasMessages && hasInput + + assert.False(t, isResponsesShape, + "when both messages and input are present, must not take the Cursor shortcut") +} + +// TestCursorMixedShapeDetection_EmptyBody ensures a body with neither +// messages nor input is NOT taken as Cursor-shape (would hit the normal +// conversion and fail on its own with a clearer error). +func TestCursorMixedShapeDetection_EmptyBody(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","stream":true}`) + + hasMessages := gjson.GetBytes(body, "messages").Exists() + hasInput := gjson.GetBytes(body, "input").Exists() + isResponsesShape := !hasMessages && hasInput + + assert.False(t, isResponsesShape, + "body with neither messages nor input must not be taken as Cursor shape") +} + +// TestCursorMixedShape_JSONRoundtrip ensures the rewritten body is still +// valid JSON and parseable back into a map without surprises — catches +// any encoding drift from sjson. +func TestCursorMixedShape_JSONRoundtrip(t *testing.T) { + cursorBody := []byte(`{"model":"gpt-5.4","stream":true,"input":[{"role":"user","content":"hi"}]}`) + + rewritten, err := sjson.SetBytes(cursorBody, "model", "gpt-5.1-codex") + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(rewritten, &parsed)) + + assert.Equal(t, "gpt-5.1-codex", parsed["model"]) + assert.Equal(t, true, parsed["stream"]) + + inputArr, ok := parsed["input"].([]any) + require.True(t, ok, "input must decode to a Go []any after round-trip") + require.Len(t, inputArr, 1) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 9b3f69bc..c827fd7b 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -16,6 +16,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "go.uber.org/zap" ) @@ -55,13 +57,52 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( compatPromptCacheInjected = promptCacheKey != "" } - // 3. Convert to Responses and forward - // ChatCompletionsToResponses always sets Stream=true (upstream always streams). - responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) - if err != nil { - return nil, fmt.Errorf("convert chat completions to responses: %w", err) + // 3. Build the upstream (Responses API) body. + // + // Cursor compatibility: some clients (notably Cursor cloud) send Responses + // API shaped bodies — `input: [...]` with no `messages` field — to the + // /v1/chat/completions URL. Running those through ChatCompletionsToResponses + // would silently drop Cursor's `input` array (the struct has no Input field) + // and produce `input: null`, which Codex upstreams reject with + // "Invalid type for 'input': expected a string, but got an object". + // + // Detect that shape and forward the raw body as-is, only rewriting `model` + // to the resolved upstream model. The downstream codex OAuth transform will + // still normalize store/stream/instructions/etc. + isResponsesShape := !gjson.GetBytes(body, "messages").Exists() && gjson.GetBytes(body, "input").Exists() + + var ( + responsesReq *apicompat.ResponsesRequest + responsesBody []byte + err error + ) + if isResponsesShape { + responsesBody, err = sjson.SetBytes(body, "model", upstreamModel) + if err != nil { + return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err) + } + // Minimal stub populated from the raw body so downstream billing + // propagation (ServiceTier, ReasoningEffort) keeps working. + responsesReq = &apicompat.ResponsesRequest{ + Model: upstreamModel, + ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(), + } + if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" { + responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort} + } + } else { + // Normal path: convert Chat Completions → Responses. + // ChatCompletionsToResponses always sets Stream=true (upstream always streams). + responsesReq, err = apicompat.ChatCompletionsToResponses(&chatReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + responsesReq.Model = upstreamModel + responsesBody, err = json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } } - responsesReq.Model = upstreamModel logFields := []zap.Field{ zap.Int64("account_id", account.ID), @@ -69,6 +110,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( zap.String("billing_model", billingModel), zap.String("upstream_model", upstreamModel), zap.Bool("stream", clientStream), + zap.Bool("responses_shape", isResponsesShape), } if compatPromptCacheInjected { logFields = append(logFields, @@ -78,12 +120,6 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } logger.L().Debug("openai chat_completions: model mapping applied", logFields...) - // 4. Marshal Responses request body, then apply OAuth codex transform - responsesBody, err := json.Marshal(responsesReq) - if err != nil { - return nil, fmt.Errorf("marshal responses request: %w", err) - } - if account.Type == AccountTypeOAuth { var reqBody map[string]any if err := json.Unmarshal(responsesBody, &reqBody); err != nil { From 422e25c99f2a3a6d16198a2fbeb2eb64cbc912d1 Mon Sep 17 00:00:00 2001 From: shuanbao0 Date: Sat, 11 Apr 2026 22:48:45 +0800 Subject: [PATCH 002/122] =?UTF-8?q?fix(gateway):=20=E5=89=A5=E7=A6=BB=20Cu?= =?UTF-8?q?rsor=20raw=20body=20=E9=80=8F=E4=BC=A0=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E4=B8=AD=20Codex=20=E4=B8=8D=E6=94=AF=E6=8C=81=E7=9A=84=20Resp?= =?UTF-8?q?onses=20API=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在前一个 commit 的 isResponsesShape 短路路径基础上,补充对 Cursor 云端 带过来的、Codex 上游统一不支持的顶层 Responses API 参数的剥离: - prompt_cache_retention - safety_identifier - metadata - stream_options 根因补充:这条 raw-body 透传路径为了保留 Cursor 的 input 数组整体结构, 不再经过 ChatCompletionsRequest 的反序列化过滤,所以这些 Go 结构体里 没有对应字段的参数会被原样发到上游,上游返回: Unsupported parameter: 常规 Chat Completions 转换路径天然通过 ChatCompletionsRequest 丢弃未知字段, 不受影响;此处仅在 isResponsesShape 分支内用 sjson.DeleteBytes 显式过滤, 作用域最小。剥离列表与 openai_gateway_service.go:2034 的 unsupportedFields 语义对齐。 另外在 applyCodexOAuthTransform 的 OAuth 兜底 strip 列表里同步追加 prompt_cache_retention,作为对该函数所有其他 OAuth 调用点的 defense in depth(当前只有 Cursor 路径的短路已在前面剥过,但保留这一层更稳)。 测试: - TestCursorMixedShape_StripsUnsupportedFields — 验证所有 4 个字段都被剥 - TestApplyCodexOAuthTransform_StripsPromptCacheRetention — OAuth 兜底路径 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../service/openai_codex_transform.go | 8 ++++ .../service/openai_codex_transform_test.go | 20 +++++++++ .../openai_cursor_warmup_pipeline_test.go | 44 +++++++++++++++++++ .../openai_gateway_chat_completions.go | 26 +++++++++++ 4 files changed, 98 insertions(+) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 4ec038e0..a266d6a0 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -124,6 +124,14 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact "top_p", "frequency_penalty", "presence_penalty", + // prompt_cache_retention is a newer Responses API parameter (cache TTL). + // The ChatGPT internal Codex endpoint rejects it with + // "Unsupported parameter: prompt_cache_retention". Defense-in-depth + // for any OAuth path that reaches this transform — the Cursor + // Responses-shape short-circuit in ForwardAsChatCompletions strips + // it earlier too, but we keep this line so other OAuth callers are + // equally protected. + "prompt_cache_retention", } { if _, ok := reqBody[key]; ok { delete(reqBody, key) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 889ac615..993ade07 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -481,6 +481,26 @@ func TestExtractSystemMessagesFromInput(t *testing.T) { }) } +// TestApplyCodexOAuthTransform_StripsPromptCacheRetention is a regression +// test: some clients (e.g. Cursor cloud via the Responses-shape compat path) +// send prompt_cache_retention, but the ChatGPT internal Codex endpoint +// rejects it with "Unsupported parameter: prompt_cache_retention". +func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.1", + "prompt_cache_retention": "24h", + "input": []any{ + map[string]any{"role": "user", "content": "hi"}, + }, + } + + applyCodexOAuthTransform(reqBody, false, false) + + _, stillThere := reqBody["prompt_cache_retention"] + require.False(t, stillThere, + "prompt_cache_retention must be stripped before forwarding to Codex upstream") +} + func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.1", diff --git a/backend/internal/service/openai_cursor_warmup_pipeline_test.go b/backend/internal/service/openai_cursor_warmup_pipeline_test.go index 8ade9dbb..19bb13d6 100644 --- a/backend/internal/service/openai_cursor_warmup_pipeline_test.go +++ b/backend/internal/service/openai_cursor_warmup_pipeline_test.go @@ -153,3 +153,47 @@ func TestCursorMixedShape_JSONRoundtrip(t *testing.T) { require.True(t, ok, "input must decode to a Go []any after round-trip") require.Len(t, inputArr, 1) } + +// TestCursorMixedShape_StripsUnsupportedFields mirrors the strip loop in +// ForwardAsChatCompletions (isResponsesShape branch). Cursor cloud sends +// prompt_cache_retention, safety_identifier, metadata and stream_options +// as top-level Responses API parameters, which Codex upstreams reject with +// "Unsupported parameter: ...". The fix must remove them from the raw body +// before it is forwarded, for BOTH OAuth and API Key account types. +func TestCursorMixedShape_StripsUnsupportedFields(t *testing.T) { + cursorBody := []byte(`{ + "model": "gpt-5.4", + "stream": true, + "prompt_cache_retention": "24h", + "safety_identifier": "cursor-user-xyz", + "metadata": {"trace_id":"abc","caller":"cursor"}, + "stream_options": {"include_usage": true}, + "input": [{"role":"user","content":"hi"}] + }`) + + // Sanity: the test fixture contains every field the production code strips. + for _, field := range cursorResponsesUnsupportedFields { + require.True(t, gjson.GetBytes(cursorBody, field).Exists(), + "test fixture must contain %s", field) + } + + // Run the exact same loop as the production code. + result := cursorBody + for _, field := range cursorResponsesUnsupportedFields { + if stripped, err := sjson.DeleteBytes(result, field); err == nil { + result = stripped + } + } + + // All unsupported fields must be gone. + for _, field := range cursorResponsesUnsupportedFields { + assert.False(t, gjson.GetBytes(result, field).Exists(), + "%s must be stripped", field) + } + + // Everything else must survive intact. + assert.Equal(t, "gpt-5.4", gjson.GetBytes(result, "model").String()) + assert.Equal(t, true, gjson.GetBytes(result, "stream").Bool()) + assert.True(t, gjson.GetBytes(result, "input").IsArray()) + assert.Equal(t, "user", gjson.GetBytes(result, "input.0.role").String()) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index c827fd7b..ac7d28a7 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -21,6 +21,22 @@ import ( "go.uber.org/zap" ) +// cursorResponsesUnsupportedFields are top-level Responses API parameters that +// Codex upstreams reject with "Unsupported parameter: ...". They must be +// stripped when forwarding a raw client body through the Responses-shape +// short-circuit in ForwardAsChatCompletions (see isResponsesShape branch). +// The normal Chat Completions → Responses conversion path is unaffected +// because ChatCompletionsRequest has no fields for these parameters — unknown +// fields are dropped naturally by json.Unmarshal. Kept semantically in sync +// with the list in openai_gateway_service.go:2034 used by the /v1/responses +// passthrough path. +var cursorResponsesUnsupportedFields = []string{ + "prompt_cache_retention", + "safety_identifier", + "metadata", + "stream_options", +} + // ForwardAsChatCompletions accepts a Chat Completions request body, converts it // to OpenAI Responses API format, forwards to the OpenAI upstream, and converts // the response back to Chat Completions format. All account types (OAuth and API @@ -81,6 +97,16 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( if err != nil { return nil, fmt.Errorf("rewrite model in responses-shape body: %w", err) } + // Strip Responses API parameters that no Codex upstream accepts. + // Because this branch forwards the raw body (the normal path rebuilds + // it from ChatCompletionsRequest and drops unknown fields naturally), + // we must filter these fields explicitly here — otherwise the upstream + // rejects the request with "Unsupported parameter: ...". + for _, field := range cursorResponsesUnsupportedFields { + if stripped, derr := sjson.DeleteBytes(responsesBody, field); derr == nil { + responsesBody = stripped + } + } // Minimal stub populated from the raw body so downstream billing // propagation (ServiceTier, ReasoningEffort) keeps working. responsesReq = &apicompat.ResponsesRequest{ From 3a11348119669a53bba1ab3ba948216ea67a7ca7 Mon Sep 17 00:00:00 2001 From: qingyuzhang Date: Mon, 13 Apr 2026 06:55:57 +0800 Subject: [PATCH 003/122] fix(frontend): avoid mounting hidden mobile table --- frontend/src/components/common/DataTable.vue | 64 +++++++++++++++++--- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/frontend/src/components/common/DataTable.vue b/frontend/src/components/common/DataTable.vue index 36c7e278..456b1da7 100644 --- a/frontend/src/components/common/DataTable.vue +++ b/frontend/src/components/common/DataTable.vue @@ -1,5 +1,5 @@

- @@ -330,28 +336,31 @@ const onWeeklyModeChange = (e: Event) => {
-
- $ - +
+ $ + +
+

{{ t('admin.accounts.quotaTotalLimitHint') }}

-
From 216bda58da5f52cbffaa8deb79f9891184adc315 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 17:38:33 +0800 Subject: [PATCH 064/122] fix: change quota notify threshold semantics to "remaining quota" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threshold now represents remaining quota instead of usage amount: - Fixed ($): threshold=400, limit=1000 → alert when remaining drops to $400 (i.e., usage reaches $600) - Percentage (%): threshold=30%, limit=1000 → alert when remaining drops to 30% (i.e., usage reaches $700) Also: - Rename 告警阈值 → 提醒阈值 in i18n - Widen type dropdown to w-16 for proper $ / % display --- backend/cmd/server/VERSION | 2 +- .../service/balance_notify_service.go | 15 +++-- .../components/account/QuotaNotifyToggle.vue | 58 ++++++------------- frontend/src/i18n/locales/en.ts | 2 +- frontend/src/i18n/locales/zh.ts | 2 +- 5 files changed, 30 insertions(+), 49 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0c65691..01eddd22 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.31 +0.1.110.34 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index a392a13e..f5abbacc 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -115,13 +115,18 @@ type quotaDim struct { limit float64 } -// resolvedThreshold returns the effective threshold value. -// For percentage type, it computes threshold = limit * percentage / 100. +// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point. +// The threshold represents how much quota REMAINS when the alert fires: +// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400) +// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%) func (d quotaDim) resolvedThreshold() float64 { - if d.thresholdType == thresholdTypePercentage && d.limit > 0 { - return d.limit * d.threshold / 100 + if d.limit <= 0 { + return 0 } - return d.threshold + if d.thresholdType == thresholdTypePercentage { + return d.limit * (1 - d.threshold/100) + } + return d.limit - d.threshold } // buildQuotaDims returns the three quota dimensions for notification checking. diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue index c7583a01..23979638 100644 --- a/frontend/src/components/account/QuotaNotifyToggle.vue +++ b/frontend/src/components/account/QuotaNotifyToggle.vue @@ -1,8 +1,4 @@ diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 5c593946..dcbcf03c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2264,7 +2264,7 @@ export default { quotaLimitAmount: 'Total Limit', quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.', quotaNotify: { - alert: 'Alert Threshold', + alert: 'Alert', enabled: 'Enable Alert', threshold: 'Alert Amount', thresholdPlaceholder: 'Enter percentage', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 141193c1..6dc9311c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2262,7 +2262,7 @@ export default { quotaLimitAmount: '总限额', quotaLimitAmountHint: '累计消费上限,不会自动重置。', quotaNotify: { - alert: '告警阈值', + alert: '提醒阈值', enabled: '启用告警', threshold: '告警金额', thresholdPlaceholder: '输入百分比', From e27335acdd1667c6ae5d35d1bc362412adb98b3f Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:23:20 +0800 Subject: [PATCH 065/122] fix(ui): widen notify type dropdown to show % fully, align quota input widths --- backend/cmd/server/VERSION | 2 +- .../src/components/account/QuotaLimitCard.vue | 161 ++++++------------ .../components/account/QuotaNotifyToggle.vue | 2 +- 3 files changed, 54 insertions(+), 111 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 01eddd22..0a81f94d 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.34 +0.1.110.38 diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index ab109a9a..9051b9be 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -196,171 +196,114 @@ const onWeeklyModeChange = (e: Event) => { -
+
- -
-
- $ - + +
+ {{ t('admin.accounts.quotaDailyLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+ +
+
+ $ +
-
+
-
-

- - +

+ +

- -
-
- $ - +
+ {{ t('admin.accounts.quotaWeeklyLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+
+
+ $ +
-
+
-
-

- - +

+ +

- +
-
- -
-
- $ - +
+ {{ t('admin.accounts.quotaTotalLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+
+
+ $ +
-

{{ t('admin.accounts.quotaTotalLimitHint') }}

+

{{ t('admin.accounts.quotaTotalLimitHint') }}

diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue index 23979638..0548f661 100644 --- a/frontend/src/components/account/QuotaNotifyToggle.vue +++ b/frontend/src/components/account/QuotaNotifyToggle.vue @@ -42,7 +42,7 @@ const emit = defineEmits<{ +

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

+
@@ -3027,6 +3032,7 @@ const form = reactive({ // Balance & quota notification balance_low_notify_enabled: false, balance_low_notify_threshold: 0, + balance_low_notify_recharge_url: '', account_quota_notify_enabled: false, account_quota_notify_emails: [] as NotifyEmailEntry[] }) @@ -3598,6 +3604,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, + balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || '', account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From 48b6c4811f2f8b58933e30c91687acb027aeedf4 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:44:36 +0800 Subject: [PATCH 067/122] fix(notify): auto-fill recharge URL with current origin when empty --- backend/cmd/server/VERSION | 2 +- frontend/src/views/admin/SettingsView.vue | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index ee13da53..9e3ea640 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.39 +0.1.110.40 diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index ef67ce22..57f6b35e 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2707,7 +2707,7 @@
- +

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

@@ -3262,6 +3262,8 @@ const addQuotaNotifyEmail = () => { form.account_quota_notify_emails.push({ email: '', disabled: false, verified: true }) } +const currentOrigin = typeof window !== 'undefined' ? window.location.origin : '' + // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -3604,7 +3606,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, - balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || '', + balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || currentOrigin, account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From f571d8ffad4ac8f3c20d44a70221a0d1f7211d28 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:52:02 +0800 Subject: [PATCH 068/122] fix(notify): write back auto-filled recharge URL to form on save --- backend/cmd/server/VERSION | 2 +- frontend/src/views/admin/SettingsView.vue | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 9e3ea640..a69c1172 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.40 +0.1.110.41 diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 57f6b35e..3ef1c0ba 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3606,7 +3606,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, - balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || currentOrigin, + balance_low_notify_recharge_url: (form.balance_low_notify_recharge_url = form.balance_low_notify_recharge_url || currentOrigin), account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From 6e9146e746a8743ad078230aedaabae85ef15e25 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 19:02:40 +0800 Subject: [PATCH 069/122] fix(notify): add recharge URL to admin settings GET response --- backend/cmd/server/VERSION | 2 +- backend/internal/handler/admin/setting_handler.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index a69c1172..ab6fbb6e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.41 +0.1.110.42 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e31eb134..bc6d183c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -178,6 +178,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), PaymentEnabled: paymentCfg.Enabled, From a43da6225449c68f486b796139a4144c9fbe24fc Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 20:01:25 +0800 Subject: [PATCH 070/122] fix(accounts): unify modal width, add notify props to create, fix quota layout - EditAccountModal width changed from "normal" to "wide" (match CreateAccountModal) - CreateAccountModal now passes all quota notify props to QuotaLimitCard - QuotaLimitCard: when global notify disabled, hide title row, input takes full width - Quota alert email: show remaining quota + threshold (fixed/$, percentage/%) instead of usage trigger point --- backend/cmd/server/VERSION | 2 +- .../service/balance_notify_service.go | 35 +++++++---- .../components/account/CreateAccountModal.vue | 61 +++++++++++++++++++ .../components/account/EditAccountModal.vue | 2 +- .../src/components/account/QuotaLimitCard.vue | 23 ++++--- 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index ab6fbb6e..76129f5c 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.42 +0.1.110.44 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 3951e88f..5191a26e 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -257,7 +257,7 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account slog.Error("panic in quota notification", "recover", r) } }() - s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim.name, newUsed, dim.limit, effectiveThreshold, siteName) + s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName) }() } @@ -384,15 +384,25 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam } // sendQuotaAlertEmails sends quota alert notification to admin emails. -func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform, dimension string, used, limit, threshold float64, siteName string) { - dimLabel := quotaDimLabels[dimension] +func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) { + dimLabel := quotaDimLabels[dim.name] if dimLabel == "" { - dimLabel = dimension + dimLabel = dim.name + } + + // Format the remaining-based threshold for display + thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold) + if dim.thresholdType == thresholdTypePercentage { + thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold) + } + remaining := dim.limit - used + if remaining < 0 { + remaining = 0 } subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) - body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName)) - s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) + body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName)) + s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name) } // sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection. @@ -440,7 +450,7 @@ const balanceLowEmailTemplate = ` ` // quotaAlertEmailTemplate is the HTML template for account quota alert notifications. -// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, threshold. +// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay. const quotaAlertEmailTemplate = ` @@ -469,10 +479,11 @@ const quotaAlertEmailTemplate = `
维度 / Dimension%s
已使用 / Used$%.2f
限额 / Limit%s
-
告警阈值 / Threshold$%.2f
+
剩余额度 / Remaining$%.2f
+
提醒阈值 / Alert Threshold%s
-

账号配额用量已达到告警阈值,请及时关注。

-

Account quota usage has reached the alert threshold.

+

账号剩余额度已低于提醒阈值,请及时关注。

+

Account remaining quota has fallen below the alert threshold.

@@ -490,11 +501,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance } // buildQuotaAlertEmailBody builds HTML email for account quota alert. -func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, threshold float64, siteName string) string { +func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string { limitStr := fmt.Sprintf("$%.2f", limit) if limit <= 0 { limitStr = "无限制 / Unlimited" } - return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, threshold) + return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay) } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index a1496fa8..ba7bad51 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1493,6 +1493,15 @@ :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" + :quotaNotifyDailyEnabled="quotaNotifyDailyEnabled" + :quotaNotifyDailyThreshold="quotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType" + :quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled" + :quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType" + :quotaNotifyTotalEnabled="quotaNotifyTotalEnabled" + :quotaNotifyTotalThreshold="quotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType" :dailyResetMode="editDailyResetMode" :dailyResetHour="editDailyResetHour" :weeklyResetMode="editWeeklyResetMode" @@ -1502,6 +1511,15 @@ @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:quotaNotifyDailyEnabled="quotaNotifyDailyEnabled = $event" + @update:quotaNotifyDailyThreshold="quotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType = $event" + @update:quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled = $event" + @update:quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType = $event" + @update:quotaNotifyTotalEnabled="quotaNotifyTotalEnabled = $event" + @update:quotaNotifyTotalThreshold="quotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType = $event" @update:dailyResetMode="editDailyResetMode = $event" @update:dailyResetHour="editDailyResetHour = $event" @update:weeklyResetMode="editWeeklyResetMode = $event" @@ -1527,6 +1545,15 @@ :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" + :quotaNotifyDailyEnabled="quotaNotifyDailyEnabled" + :quotaNotifyDailyThreshold="quotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType" + :quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled" + :quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType" + :quotaNotifyTotalEnabled="quotaNotifyTotalEnabled" + :quotaNotifyTotalThreshold="quotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType" :dailyResetMode="editDailyResetMode" :dailyResetHour="editDailyResetHour" :weeklyResetMode="editWeeklyResetMode" @@ -1536,6 +1563,15 @@ @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:quotaNotifyDailyEnabled="quotaNotifyDailyEnabled = $event" + @update:quotaNotifyDailyThreshold="quotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType = $event" + @update:quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled = $event" + @update:quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType = $event" + @update:quotaNotifyTotalEnabled="quotaNotifyTotalEnabled = $event" + @update:quotaNotifyTotalThreshold="quotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType = $event" @update:dailyResetMode="editDailyResetMode = $event" @update:dailyResetHour="editDailyResetHour = $event" @update:weeklyResetMode="editWeeklyResetMode = $event" @@ -3041,6 +3077,15 @@ const anthropicPassthroughEnabled = ref(false) const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) const quotaNotifyGlobalEnabled = ref(false) +const quotaNotifyDailyEnabled = ref(null) +const quotaNotifyDailyThreshold = ref(null) +const quotaNotifyDailyThresholdType = ref(null) +const quotaNotifyWeeklyEnabled = ref(null) +const quotaNotifyWeeklyThreshold = ref(null) +const quotaNotifyWeeklyThresholdType = ref(null) +const quotaNotifyTotalEnabled = ref(null) +const quotaNotifyTotalThreshold = ref(null) +const quotaNotifyTotalThresholdType = ref(null) // Load global feature states once adminAPI.settings.getWebSearchEmulationConfig().then(cfg => { @@ -4153,6 +4198,22 @@ const createAccountAndFinish = async ( if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' } + // Quota notify config + if (quotaNotifyDailyEnabled.value) { + quotaExtra.quota_notify_daily_enabled = true + if (quotaNotifyDailyThreshold.value != null) quotaExtra.quota_notify_daily_threshold = quotaNotifyDailyThreshold.value + quotaExtra.quota_notify_daily_threshold_type = quotaNotifyDailyThresholdType.value || 'fixed' + } + if (quotaNotifyWeeklyEnabled.value) { + quotaExtra.quota_notify_weekly_enabled = true + if (quotaNotifyWeeklyThreshold.value != null) quotaExtra.quota_notify_weekly_threshold = quotaNotifyWeeklyThreshold.value + quotaExtra.quota_notify_weekly_threshold_type = quotaNotifyWeeklyThresholdType.value || 'fixed' + } + if (quotaNotifyTotalEnabled.value) { + quotaExtra.quota_notify_total_enabled = true + if (quotaNotifyTotalThreshold.value != null) quotaExtra.quota_notify_total_threshold = quotaNotifyTotalThreshold.value + quotaExtra.quota_notify_total_threshold_type = quotaNotifyTotalThresholdType.value || 'fixed' + } if (Object.keys(quotaExtra).length > 0) { finalExtra = quotaExtra } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 613738d2..92761b35 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -2,7 +2,7 @@
{
- -
+ +
{{ t('admin.accounts.quotaDailyLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
@@ -238,12 +239,13 @@ const onWeeklyModeChange = (e: Event) => {
-
+
{{ t('admin.accounts.quotaWeeklyLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
@@ -287,12 +289,13 @@ const onWeeklyModeChange = (e: Event) => {
-
+
{{ t('admin.accounts.quotaTotalLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
From ca673f98995b286c35b19b659088c5970b4eea54 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 20:35:38 +0800 Subject: [PATCH 071/122] test: add 66 unit tests for balance/quota notify + plan validation balance_notify_service_test.go (27 tests): - resolveBalanceThreshold: fixed/percentage/zero recharged/empty type - quotaDim.resolvedThreshold: fixed normal/exceed/equal limit, percentage 0/30/100/>100, zero/negative limit - sanitizeEmailHeader: CRLF/CR/LF/clean/empty/multiple newlines - buildQuotaDims / buildQuotaDimsFromState: all dimensions, empty extra, state-vs-account precedence - collectBalanceNotifyRecipients: empty, filter disabled/unverified, case-insensitive dedup, skip empty, trim balance_notify_check_test.go (16 tests): - CheckBalanceAfterDeduction guard clauses: nil user/disabled/global-off/threshold=0/user-override/no-crossing - CheckAccountQuotaAfterIncrement guards: nil account/zero cost/negative cost/global-disabled - getBalanceNotifyConfig: all fields, disabled, invalid threshold - isAccountQuotaNotifyEnabled: missing/false/true - getSiteName: default fallback + configured balance_notify_email_body_test.go (10 tests): - Guards against fmt.Sprintf arg-count mismatches in email templates - Verifies HTML escaping of recharge URL - Verifies CSS %% escape produces literal % in output - Verifies unlimited/percentage/over-quota display branches payment_config_plans_validation_test.go (13 tests): - validatePlanRequired: all 5 validation branches + whitespace handling --- .../service/balance_notify_check_test.go | 180 +++++++++++ .../service/balance_notify_email_body_test.go | 147 +++++++++ .../service/balance_notify_service_test.go | 280 ++++++++++++++++++ .../payment_config_plans_validation_test.go | 89 ++++++ 4 files changed, 696 insertions(+) create mode 100644 backend/internal/service/balance_notify_check_test.go create mode 100644 backend/internal/service/balance_notify_email_body_test.go create mode 100644 backend/internal/service/balance_notify_service_test.go create mode 100644 backend/internal/service/payment_config_plans_validation_test.go diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go new file mode 100644 index 00000000..955f3129 --- /dev/null +++ b/backend/internal/service/balance_notify_check_test.go @@ -0,0 +1,180 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an +// in-memory settings repo and a non-nil emailService so that the guard-clause +// nil-checks pass. The emailService is intentionally minimal — tests must +// avoid crossing scenarios that would actually dispatch emails. +func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) { + repo := newMockSettingRepo() + // EmailService is a concrete type; construct with the same repo so that + // any accidental fallback reads still succeed. Tests should not trigger a + // crossing that reaches SendEmail. + email := NewEmailService(repo, nil) + return NewBalanceNotifyService(email, repo, nil), repo +} + +// ---------- guard clauses ---------- + +func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50) +} + +func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: false} + // Even with a crossing, disabled flag short-circuits. + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "0" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default + customThreshold := 5.0 + u := &User{ + ID: 1, + BalanceNotifyEnabled: true, + BalanceNotifyThreshold: &customThreshold, + } + // User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not + // cross 5, so nothing fires (verified by absence of panic). + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: true} + + // 100 -> 95, both remain above threshold=10, no crossing. + s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5) + // 5 -> 3, both already below threshold, no crossing (only fires on first + // cross from above-to-below). + s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2) +} + +// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ---------- + +func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil) +} + +func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil) +} + +func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil) +} + +func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + a := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_notify_daily_enabled": true, + "quota_notify_daily_threshold": 100.0, + "quota_daily_limit": 1000.0, + "quota_daily_used": 950.0, + }, + } + // Global disabled → no processing even if a dim would cross. + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil) +} + +// ---------- sanity: internal helpers still work ---------- + +func TestGetBalanceNotifyConfig_AllFields(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5" + repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay" + + enabled, threshold, url := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 12.5, threshold) + require.Equal(t, "https://example.com/pay", url) +} + +func TestGetBalanceNotifyConfig_Disabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + + enabled, _, _ := s.getBalanceNotifyConfig(context.Background()) + require.False(t, enabled) +} + +func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number" + + enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 0.0, threshold) +} + +func TestIsAccountQuotaNotifyEnabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + + // Missing key → false + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "false" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "true" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true" + require.True(t, s.isAccountQuotaNotifyEnabled(context.Background())) +} + +func TestGetSiteName_FallsBackToDefault(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + name := s.getSiteName(context.Background()) + require.Equal(t, defaultSiteName, name) +} + +func TestGetSiteName_Configured(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeySiteName] = "My Site" + require.Equal(t, "My Site", s.getSiteName(context.Background())) +} diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go new file mode 100644 index 00000000..9baf164e --- /dev/null +++ b/backend/internal/service/balance_notify_email_body_test.go @@ -0,0 +1,147 @@ +//go:build unit + +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// These tests guard against fmt.Sprintf arg-count mismatches in the email +// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in +// the output, which these assertions will catch. + +// ---------- buildBalanceLowEmailBody ---------- + +func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "") + + // All substituted values should appear in the output. + require.Contains(t, body, "MySite") + require.Contains(t, body, "Alice") + require.Contains(t, body, "$3.14") + require.Contains(t, body, "$10.00") + + // No fmt.Sprintf format error markers. + require.NotContains(t, body, "%!") + require.NotContains(t, body, "MISSING") + require.NotContains(t, body, "EXTRA") +} + +func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay") + + // The recharge anchor element should appear with the URL. + require.Contains(t, body, `href="https://example.com/pay"`) + require.Contains(t, body, "立即充值") + require.NotContains(t, body, "%!") +} + +func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) { + s := &BalanceNotifyService{} + // Try a URL with characters that need HTML escaping. + body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b= + + diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 5f0c7c2c..77e437a8 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -1,7 +1,7 @@ diff --git a/frontend/src/composables/useQuotaNotifyState.ts b/frontend/src/composables/useQuotaNotifyState.ts new file mode 100644 index 00000000..1c6705d3 --- /dev/null +++ b/frontend/src/composables/useQuotaNotifyState.ts @@ -0,0 +1,69 @@ +import { reactive, ref } from 'vue' +import { adminAPI } from '@/api/admin' +import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account' + +export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const +export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number] + +interface DimState { + enabled: boolean | null + threshold: number | null + thresholdType: string | null +} + +export function useQuotaNotifyState() { + const globalEnabled = ref(false) + const state = reactive>({ + daily: { enabled: null, threshold: null, thresholdType: null }, + weekly: { enabled: null, threshold: null, thresholdType: null }, + total: { enabled: null, threshold: null, thresholdType: null }, + }) + + function loadGlobalState() { + adminAPI.settings + .getSettings() + .then((settings) => { + globalEnabled.value = settings.account_quota_notify_enabled === true + }) + .catch(() => { + globalEnabled.value = false + }) + } + + function loadFromExtra(extra: Record | null | undefined) { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null + state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null + state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null + } + } + + function writeToExtra(extra: Record, mode: 'create' | 'update') { + for (const d of QUOTA_NOTIFY_DIMS) { + const s = state[d] + if (s.enabled) { + extra[`quota_notify_${d}_enabled`] = true + if (s.threshold != null) { + extra[`quota_notify_${d}_threshold`] = s.threshold + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_threshold`] + } + extra[`quota_notify_${d}_threshold_type`] = s.thresholdType || QUOTA_THRESHOLD_TYPE_FIXED + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_enabled`] + delete extra[`quota_notify_${d}_threshold`] + delete extra[`quota_notify_${d}_threshold_type`] + } + } + } + + function reset() { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = null + state[d].threshold = null + state[d].thresholdType = null + } + } + + return { globalEnabled, state, loadGlobalState, loadFromExtra, writeToExtra, reset } +} diff --git a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue index 639b4f66..28b82da5 100644 --- a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue @@ -29,7 +29,7 @@ @@ -67,86 +67,14 @@
- - -
-
- - -
-
- - -
-
- - -
-
- -
-
-
{{ t('payment.admin.dailyLimit') }}: {{ selectedGroupInfo.daily_limit_usd != null ? '$' + selectedGroupInfo.daily_limit_usd : t('payment.admin.unlimited') }}
-
{{ t('payment.admin.weeklyLimit') }}: {{ selectedGroupInfo.weekly_limit_usd != null ? '$' + selectedGroupInfo.weekly_limit_usd : t('payment.admin.unlimited') }}
-
{{ t('payment.admin.monthlyLimit') }}: {{ selectedGroupInfo.monthly_limit_usd != null ? '$' + selectedGroupInfo.monthly_limit_usd : t('payment.admin.unlimited') }}
-
-
- -
-
-
-
-
-
-
-
-
-

{{ t('payment.admin.featuresHint') }}

-
-
- - -
- - - + From 74f8a30f861f2b5072f5916265abbf61d3448b12 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 23:35:59 +0800 Subject: [PATCH 076/122] fix: address audit findings for websearch, email verification, and pricing - Fix websearch provider failover: proxy error from provider-specific proxy now continues to next provider instead of aborting the entire loop - Fix SMTP failure locking users out: send email first, then write cache and increment rate counter - Fix notify email cache key case sensitivity: normalize to lowercase - Add OriginalPrice validation to validatePlanPatch and validatePlanRequired - Add empty scope validation for channel pricing rules (group_ids/account_ids) - Add platform color to account search dropdown in channel pricing rules --- .../internal/handler/admin/channel_handler.go | 11 +++ backend/internal/pkg/websearch/manager.go | 13 +++- backend/internal/repository/email_cache.go | 5 +- .../internal/service/payment_config_plans.go | 10 ++- .../payment_config_plans_validation_test.go | 75 ++++++++++++++----- backend/internal/service/user_service.go | 8 +- frontend/src/views/admin/ChannelsView.vue | 7 +- 7 files changed, 103 insertions(+), 26 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 2d4cd56a..ee76a750 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "strconv" "strings" @@ -351,6 +352,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) @@ -409,6 +415,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { if req.AccountStatsPricingRules != nil { statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules)) for i, r := range *req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index ae0683ad..27592459 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -111,9 +111,18 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) } if isProxyError(err) { m.markProxyUnavailable(ctx, cfg, req.ProxyURL) - slog.Warn("websearch: proxy error, marking unavailable", + if req.ProxyURL != "" { + // Account-level proxy is shared by all providers — no point + // trying others with the same broken proxy; signal account switch. + slog.Warn("websearch: account proxy error, aborting failover", + "provider", cfg.Type, "error", err) + return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + } + // Provider-specific proxy failed — try the next provider which + // may use a different (or no) proxy. + slog.Warn("websearch: provider proxy error, trying next provider", "provider", cfg.Type, "error", err) - return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + continue } slog.Warn("websearch: provider search failed", "provider", cfg.Type, "error", err) diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index ed903e0d..1356163d 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -24,8 +25,10 @@ func verifyCodeKey(email string) string { } // notifyVerifyKey generates the Redis key for notify email verification code. +// Email is lowercased to prevent case-sensitive key mismatch (the business layer +// uses strings.EqualFold for comparison). func notifyVerifyKey(email string) string { - return notifyVerifyKeyPrefix + email + return notifyVerifyKeyPrefix + strings.ToLower(email) } // passwordResetKey generates the Redis key for password reset token. diff --git a/backend/internal/service/payment_config_plans.go b/backend/internal/service/payment_config_plans.go index 8a5e1924..6753071d 100644 --- a/backend/internal/service/payment_config_plans.go +++ b/backend/internal/service/payment_config_plans.go @@ -12,7 +12,7 @@ import ( ) // validatePlanRequired checks that all required fields for a plan are provided. -func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string) error { +func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error { if strings.TrimSpace(name) == "" { return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required") } @@ -28,6 +28,9 @@ func validatePlanRequired(name string, groupID int64, price float64, validityDay if strings.TrimSpace(validityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if originalPrice != nil && *originalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -48,6 +51,9 @@ func validatePlanPatch(req UpdatePlanRequest) error { if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if req.OriginalPrice != nil && *req.OriginalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -115,7 +121,7 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S } func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { - if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit); err != nil { + if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil { return nil, err } b := s.entClient.SubscriptionPlan.Create(). diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index bc9c0048..9a2d8716 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -9,81 +9,122 @@ import ( ) func TestValidatePlanRequired_AllValid(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "days") + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil) require.NoError(t, err) } func TestValidatePlanRequired_EmptyName(t *testing.T) { - err := validatePlanRequired("", 1, 9.99, 30, "days") + err := validatePlanRequired("", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_WhitespaceName(t *testing.T) { - err := validatePlanRequired(" ", 1, 9.99, 30, "days") + err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_ZeroGroupID(t *testing.T) { - err := validatePlanRequired("Pro", 0, 9.99, 30, "days") + err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_NegativeGroupID(t *testing.T) { - err := validatePlanRequired("Pro", -1, 9.99, 30, "days") + err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_ZeroPrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, 0, 30, "days") + err := validatePlanRequired("Pro", 1, 0, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_NegativePrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, -5, 30, "days") + err := validatePlanRequired("Pro", 1, -5, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 0, "days") + err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, -7, "days") + err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "") + err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, " ") + err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) { - // When multiple fields are invalid, name should be reported first - // (follows the order of checks in the function). - err := validatePlanRequired("", 0, 0, 0, "") + err := validatePlanRequired("", 0, 0, 0, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_TrimmedValidName(t *testing.T) { - // Whitespace-surrounded but non-empty name is accepted (trimmed check only - // rejects pure whitespace). - err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days") + err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil) + require.NoError(t, err) +} + +func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) { + neg := -10.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero) + require.NoError(t, err) +} + +func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) { + op := 19.99 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op) + require.NoError(t, err) +} + +// --- validatePlanPatch tests --- + +func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) { + neg := -5.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg}) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) { + op := 29.99 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) require.NoError(t, err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0da73762..7602d162 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -291,6 +291,12 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema return fmt.Errorf("generate code: %w", err) } + // Send email first — if SMTP fails, don't write cache or increment counters, + // so the user is not locked out by cooldown/rate-limit for a code they never received. + if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil { + return err + } + if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil { return err } @@ -300,7 +306,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err) } - return s.sendNotifyVerifyEmail(ctx, emailService, email, code) + return nil } // checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit. diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 2ca1141d..60704b65 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -511,7 +511,7 @@ :class="{ 'opacity-50': rule.account_ids.includes(account.id) }" :disabled="rule.account_ids.includes(account.id)" > - {{ account.name }} + {{ account.name }} #{{ account.id }}
@@ -595,6 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' +import { platformTextClass } from '@/utils/platformColors' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -911,7 +912,7 @@ function getGroupNameById(groupId: number): string { } // ── Account search for pricing rules ── -interface SimpleAccount { id: number; name: string } +interface SimpleAccount { id: number; name: string; platform: string } const ruleAccountSearchKeyword = ref>({}) const ruleAccountSearchResults = ref>({}) @@ -924,7 +925,7 @@ const ruleAccountSearchRunner = useKeyedDebouncedSearch({ search: async (keyword, { key, signal }) => { const platform = key.split('-')[0] const res = await adminAPI.accounts.list(1, 20, { platform, search: keyword }, { signal }) - return res.items.map(a => ({ id: a.id, name: a.name })) + return res.items.map(a => ({ id: a.id, name: a.name, platform: a.platform })) }, onSuccess: (key, result) => { ruleAccountSearchResults.value[key] = result }, onError: (key) => { ruleAccountSearchResults.value[key] = [] }, From a9880ee7b92365482154f986ea4b4d5256ffd7ec Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 00:26:20 +0800 Subject: [PATCH 077/122] =?UTF-8?q?fix:=20round-2=20audit=20fixes=20?= =?UTF-8?q?=E2=80=94=20security,=20code=20quality,=20and=20UI=20improvemen?= =?UTF-8?q?ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Security (HIGH): - Normalize all Redis cache keys to lowercase (verifyCode, passwordReset) - Fix verify code TTL renewal on failed attempts: use remaining TTL via ExpiresAt field instead of resetting to full 15-minute window - Add 3 missing fields to diffSettings audit log (promo_code, invitation_code, custom_endpoints) Code quality (MEDIUM): - Extract filterVerifiedEmails shared helper (balance_notify_service.go) - Add Pricing array non-empty validation for channel pricing rules - Add platform token semantics comment in gateway_service.go - Complete validatePlanPatch test coverage (+10 test cases) - Replace string types with QuotaThresholdType/QuotaResetMode across frontend - Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView - Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss UI improvements: - Reorder cost tooltip: user billing above separator, account billing below - Add NaN guard to accountBilled function - Move timezone selector inline into reset-mode row (no longer standalone) --- .../internal/handler/admin/channel_handler.go | 10 + .../internal/handler/admin/setting_handler.go | 9 + backend/internal/repository/email_cache.go | 7 +- .../service/balance_notify_service.go | 44 +- backend/internal/service/email_service.go | 8 +- backend/internal/service/gateway_service.go | 591 +++++++++++++----- .../payment_config_plans_validation_test.go | 63 ++ backend/internal/service/user_service.go | 15 +- .../components/account/QuotaDimensionRow.vue | 28 +- .../src/components/account/QuotaLimitCard.vue | 48 +- .../components/account/QuotaNotifyToggle.vue | 8 +- .../src/components/admin/usage/UsageTable.vue | 12 +- .../src/composables/useQuotaNotifyState.ts | 6 +- frontend/src/constants/account.ts | 5 + frontend/src/views/admin/ChannelsView.vue | 42 +- 15 files changed, 605 insertions(+), 291 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index ee76a750..1a328551 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) return } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) @@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) return } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0c1606ea..2324cc70 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { changed = append(changed, "registration_email_suffix_whitelist") } + if before.PromoCodeEnabled != after.PromoCodeEnabled { + changed = append(changed, "promo_code_enabled") + } + if before.InvitationCodeEnabled != after.InvitationCodeEnabled { + changed = append(changed, "invitation_code_enabled") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } @@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.CustomMenuItems != after.CustomMenuItems { changed = append(changed, "custom_menu_items") } + if before.CustomEndpoints != after.CustomEndpoints { + changed = append(changed, "custom_endpoints") + } if before.EnableFingerprintUnification != after.EnableFingerprintUnification { changed = append(changed, "enable_fingerprint_unification") } diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 1356163d..0eb6bef1 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -20,8 +20,9 @@ const ( ) // verifyCodeKey generates the Redis key for email verification code. +// Email is lowercased for case-insensitive consistency. func verifyCodeKey(email string) string { - return verifyCodeKeyPrefix + email + return verifyCodeKeyPrefix + strings.ToLower(email) } // notifyVerifyKey generates the Redis key for notify email verification code. @@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string { // passwordResetKey generates the Redis key for password reset token. func passwordResetKey(email string) string { - return passwordResetKeyPrefix + email + return passwordResetKeyPrefix + strings.ToLower(email) } // passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. func passwordResetSentAtKey(email string) string { - return passwordResetSentAtKeyPrefix + email + return passwordResetSentAtKeyPrefix + strings.ToLower(email) } type emailCache struct { diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 9a75d6be..5e9afcc8 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -283,6 +283,20 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) return nil } + return filterVerifiedEmails(entries) +} + +// getSiteName reads site name from settings with fallback. +func (s *BalanceNotifyService) getSiteName(ctx context.Context) string { + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || name == "" { + return defaultSiteName + } + return name +} + +// filterVerifiedEmails returns deduplicated, non-disabled, verified emails. +func filterVerifiedEmails(entries []NotifyEmailEntry) []string { var recipients []string seen := make(map[string]bool) for _, entry := range entries { @@ -303,38 +317,10 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) return recipients } -// getSiteName reads site name from settings with fallback. -func (s *BalanceNotifyService) getSiteName(ctx context.Context) string { - name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) - if err != nil || name == "" { - return defaultSiteName - } - return name -} - // collectBalanceNotifyRecipients returns verified, non-disabled email recipients. // Only emails with verified=true and disabled=false are included. func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string { - var recipients []string - seen := make(map[string]bool) - - for _, entry := range user.BalanceNotifyExtraEmails { - if entry.Disabled || !entry.Verified { - continue - } - email := strings.TrimSpace(entry.Email) - if email == "" { - continue - } - lower := strings.ToLower(email) - if seen[lower] { - continue - } - seen[lower] = true - recipients = append(recipients, email) - } - - return recipients + return filterVerifiedEmails(user.BalanceNotifyExtraEmails) } // sendEmails sends an email to all recipients with shared timeout and error logging. diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 50d324f2..a94e0dde 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -55,6 +55,7 @@ type VerificationCodeData struct { Code string Attempts int CreatedAt time.Time + ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts } // PasswordResetTokenData represents password reset token data @@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin Code: code, Attempts: 0, CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(verifyCodeTTL), } if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) @@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 (constant-time comparison to prevent timing attacks) if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ - if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + remaining := time.Until(data.ExpiresAt) + if remaining <= 0 { + return ErrInvalidVerifyCode + } + if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil { slog.Error("failed to update verification attempt count", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b67a06a7..c65e828a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } // 对于等待计划的情况,也需要先检查会话限制 @@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } } @@ -1433,53 +1431,76 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if s.isAccountSchedulableForSelection(stickyAccount) && + var stickyCacheMissReason string + + gatePass := s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) + + if rpmPass { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 + stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return &AccountSelectionResult{ - Account: stickyAccount, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) } } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - // 会话限制已满,继续到负载感知选择 + if stickyCacheMissReason == "" { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + stickyCacheMissReason = "session_limit" + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + stickyCacheMissReason = "wait_queue_full" } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } else if !gatePass { + stickyCacheMissReason = "gate_check" + } else { + stickyCacheMissReason = "rpm_red" + } + + // 记录粘性缓存未命中的结构化日志 + if stickyCacheMissReason != "" { + baseRPM := stickyAccount.GetBaseRPM() + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { + currentRPM = count + } + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", + stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", + stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) } } @@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - WaitPlan: &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if s.cache != nil { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + } + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } } @@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { + return nil, legacyErr + } else if ok { return result, nil } } else { @@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) } } @@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true + selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) + if err != nil { + return nil, false, err + } + return selection, true, nil } } - return nil, false + return nil, false, nil } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } +func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { @@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - detail := "generic_unschedulable" - return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -// GetAccessToken 获取账号凭证 func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } +// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, +// system 字段仅保留 Claude Code 标识提示词。 +// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 +// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 +// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 +func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { + system = normalizeSystemParam(system) + + // 1. 提取原始 system prompt 文本 + var originalSystemText string + switch v := system.(type) { + case string: + originalSystemText = strings.TrimSpace(v) + case []any: + var parts []string + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + } + originalSystemText = strings.Join(parts, "\n\n") + } + + // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) + // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 + // 使用 string 格式会被 Anthropic 检测为第三方应用。 + claudeCodeSystemBlock := []map[string]any{ + { + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + }, + } + out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") + return body + } + + // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 + // 模型仍通过 messages 接收完整指令,保留客户端功能 + ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) + if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { + instrMsg, err1 := json.Marshal(map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, + }, + }) + ackMsg, err2 := json.Marshal(map[string]any{ + "role": "assistant", + "content": []map[string]any{ + {"type": "text", "text": "Understood. I will follow these instructions."}, + }, + }) + if err1 != nil || err2 != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") + return out + } + + // 重建 messages 数组:[instruction, ack, ...originalMessages] + items := [][]byte{instrMsg, ackMsg} + messagesResult := gjson.GetBytes(out, "messages") + if messagesResult.IsArray() { + messagesResult.ForEach(func(_, msg gjson.Result) bool { + items = append(items, []byte(msg.Raw)) + return true + }) + } + + if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { + out = next + } + } + + return out +} + type cacheControlPath struct { path string log string @@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { - policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) if policy.blockErr != nil { return nil, policy.blockErr } @@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + systemRewritten := false if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) + body = rewriteSystemForNonClaudeCode(body, parsed.System) + systemRewritten = true } - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); + // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 + // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) + _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) if !mimicMPT { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true @@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint - enableFP, enableMPT := true, false + enableFP, enableMPT, enableCCH := true, false, false if s.settingService != nil { - enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) @@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if fingerprint != nil { + body = syncBillingHeaderVersion(body, fingerprint.UserAgent) + } + // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后) + if enableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // Build effective drop set: merge static defaults with dynamic beta policy filter rules - policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) effectiveDropSet := mergeDropSets(policyFilterSet) - effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Match real Claude CLI traffic (per mitmproxy reports): - // messages requests typically use only oauth + interleaved-thinking. - // Also drop claude-code beta if a downstream client added it. + // Claude Code OAuth credentials are scoped to Claude Code. + // Non-haiku models MUST include claude-code beta for Anthropic to recognize + // this as a legitimate Claude Code request; without it, the request is + // rejected as third-party ("out of extra usage"). + // Haiku models are exempt from third-party detection and don't need it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) + if !strings.Contains(strings.ToLower(modelID), "haiku") { + requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} + } + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") @@ -5716,7 +5881,7 @@ type betaPolicyResult struct { } // evaluateBetaPolicy loads settings once and evaluates all rules against the given request. -func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { if s.settingService == nil { return betaPolicyResult{} } @@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } - switch rule.Action { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + switch effectiveAction { case BetaPolicyActionBlock: if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" // In the /v1/messages path, Forward() evaluates the policy first and caches the result; // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // evaluates on demand (one DB call). -func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { if c != nil { if v, ok := c.Get(betaPolicyFilterSetKey); ok { if fs, ok := v.(map[string]struct{}); ok { @@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } } } - return s.evaluateBetaPolicy(ctx, "", account).filterSet + return s.evaluateBetaPolicy(ctx, "", account, model).filterSet } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. @@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { } } +// matchModelWhitelist checks if a model matches any pattern in the whitelist. +// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. +func matchModelWhitelist(model string, whitelist []string) bool { + for _, pattern := range whitelist { + if matchModelPattern(pattern, model) { + return true + } + } + return false +} + +// resolveRuleAction determines the effective action and error message for a rule given the request model. +// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. +// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. +func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { + if len(rule.ModelWhitelist) == 0 { + return rule.Action, rule.ErrorMessage + } + if matchModelWhitelist(model, rule.ModelWhitelist) { + return rule.Action, rule.ErrorMessage + } + if rule.FallbackAction != "" { + return rule.FallbackAction, rule.FallbackErrorMessage + } + return BetaPolicyActionPass, "" // default fallback: pass (fail-open) +} + // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( modelID string, ) ([]string, error) { // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) - policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) if policy.blockErr != nil { return nil, policy.blockErr } @@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 如果不做此检查,block 规则会被绕过。 - if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { return nil, blockErr } @@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 -func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { if s.settingService == nil || len(tokens) == 0 { return nil } @@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { - if rule.Action != BetaPolicyActionBlock { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + if effectiveAction != BetaPolicyActionBlock { continue } if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() } -// postUsageBilling 统一处理使用量记录后的扣费逻辑: -// - 订阅/余额扣费 -// - API Key 配额更新 -// - API Key 限速用量更新 -// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +// postUsageBilling is the legacy fallback billing path used when the unified +// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply +// for atomic billing. This path only runs in tests or degraded mode. func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { billingCtx, cancel := detachedBillingContext(ctx) defer cancel() cost := p.Cost - // 1. 订阅 / 余额扣费 if p.IsSubscriptionBill { if cost.TotalCost > 0 { if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) } - deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) } } - // 2. API Key 配额 if p.shouldDeductAPIKeyQuota() { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } - // 3. API Key 限速用量 if p.shouldUpdateRateLimits() { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } } - // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) if p.shouldUpdateAccountQuota() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { @@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } } - finalizePostUsageBilling(p, deps) + // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing + // cache updates. The legacy path does DB writes directly; the finalize path + // does cache queue + notifications. Notifications are dispatched separately + // by the caller after recording the usage log. } func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { @@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount - if usageLog.MediaType != nil { - cmd.MediaType = *usageLog.MediaType - } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog } } - finalizePostUsageBilling(p, deps) + finalizePostUsageBilling(p, deps, result) return true, nil } -func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { if p == nil || p.Cost == nil || deps == nil { return } @@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) - // Balance low notification — use real-time balance from billing cache (not stale snapshot) - if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil { - oldBalance := p.User.Balance // fallback to snapshot - if deps.billingCacheService != nil { - if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil { - oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance - } + // Notification checks run async — all parameters are already captured, + // no dependency on the request context or upstream connection. + go notifyBalanceLow(p, deps, result) + go notifyAccountQuota(p, deps, result) +} + +// notifyBalanceLow sends balance low notification after deduction. +// When result.NewBalance is available (from DB transaction RETURNING), it is used directly +// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races. +func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyBalanceLow", "recover", r) } - deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) + }() + if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil { + slog.Debug("notifyBalanceLow: skipped", + "is_subscription", p.IsSubscriptionBill, + "actual_cost", p.Cost.ActualCost, + "user_nil", p.User == nil, + "service_nil", deps.balanceNotifyService == nil, + ) + return } - // Account quota notification (use same cost formula as postUsageBilling) - if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil { - accountCost := p.Cost.TotalCost * p.AccountRateMultiplier - deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost) + oldBalance := resolveOldBalance(p, result) + slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction", + "user_id", p.User.ID, + "old_balance", oldBalance, + "cost", p.Cost.ActualCost, + "notify_enabled", p.User.BalanceNotifyEnabled, + "threshold", p.User.BalanceNotifyThreshold, + "result_has_new_balance", result != nil && result.NewBalance != nil, + ) + deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) +} + +// resolveOldBalance returns the pre-deduction balance. +// Prefers the DB transaction result (newBalance + cost) over snapshot. +func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 { + if result != nil && result.NewBalance != nil { + return *result.NewBalance + p.Cost.ActualCost } + // Legacy fallback: snapshot balance from request context + return p.User.Balance +} + +// notifyAccountQuota sends account quota threshold notification after increment. +// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly +// to avoid a separate DB read that may see stale or concurrently-modified data. +func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyAccountQuota", "recover", r) + } + }() + if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil { + slog.Debug("notifyAccountQuota: skipped", + "total_cost", p.Cost.TotalCost, + "account_nil", p.Account == nil, + "is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(), + "service_nil", deps.balanceNotifyService == nil, + ) + return + } + accountCost := p.Cost.TotalCost * p.AccountRateMultiplier + var quotaState *AccountQuotaState + if result != nil { + quotaState = result.QuotaState + } + slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement", + "account_id", p.Account.ID, + "account_cost", accountCost, + "has_quota_state", quotaState != nil, + ) + deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState) } func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { @@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 type recordUsageOpts struct { - // ParsedRequest(可选,仅 Claude 路径传入) + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) ParsedRequest *ParsedRequest // EnableClaudePath 启用 Claude 路径特有逻辑: - // - MediaType 字段写入使用日志 + // - Claude Max 缓存计费策略 EnableClaudePath bool // 长上下文计费(仅 Gemini 路径需要) @@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu APIKeyService: input.APIKeyService, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ - ParsedRequest: input.ParsedRequest, EnableClaudePath: true, }) } @@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: +// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { - upstreamModel := result.UpstreamModel - if upstreamModel == "" { - upstreamModel = result.Model - } - usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, upstreamModel, + applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model, + // Anthropic's input_tokens excludes cache_read and cache_creation (billed separately); + // OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason. UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage CacheReadTokens: result.Usage.CacheReadInputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens, }, - 1, // requestCount cost.TotalCost, ) } @@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, - BillingMode: resolveBillingMode(opts, result, cost), + BillingMode: resolveBillingMode(result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), - MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), @@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { +func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { var mode string switch { case cost != nil && cost.BillingMode != "": @@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost return &mode } -func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { - return nil -} - func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { return &subscription.ID @@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT := true, false + ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false if s.settingService != nil { - ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { @@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if ctFingerprint != nil && ctEnableFP { + body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) + } + if ctEnableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index 9a2d8716..efdbdb10 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) require.NoError(t, err) } + +// --- validatePlanPatch: other fields --- + +func ptrStr(s string) *string { return &s } +func ptrInt(i int) *int { return &i } +func ptrInt64(i int64) *int64 { return &i } +func ptrFloat(f float64) *float64 { return &f } + +func TestValidatePlanPatch_EmptyName(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")}) + require.Error(t, err) + require.Contains(t, err.Error(), "plan name") +} + +func TestValidatePlanPatch_ValidName(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ZeroGroupID(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "group") +} + +func TestValidatePlanPatch_NegativePrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)}) + require.Error(t, err) + require.Contains(t, err.Error(), "price") +} + +func TestValidatePlanPatch_ZeroPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "price") +} + +func TestValidatePlanPatch_ValidPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "validity days") +} + +func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")}) + require.Error(t, err) + require.Contains(t, err.Error(), "validity unit") +} + +func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_AllNil(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{}) + require.NoError(t, err) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 7602d162..a7724a5a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str Code: code, Attempts: 0, CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(verifyCodeTTL), } if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) @@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) } if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ - if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { + remaining := time.Until(data.ExpiresAt) + if remaining <= 0 { + return ErrInvalidVerifyCode + } + if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil { slog.Error("failed to update notify verify code attempts", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { @@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email } filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails)) + found := false for _, e := range user.BalanceNotifyExtraEmails { - if !strings.EqualFold(e.Email, email) { + if strings.EqualFold(e.Email, email) { + found = true + } else { filtered = append(filtered, e) } } + if !found { + return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found") + } user.BalanceNotifyExtraEmails = filtered return s.userRepo.Update(ctx, user) } diff --git a/frontend/src/components/account/QuotaDimensionRow.vue b/frontend/src/components/account/QuotaDimensionRow.vue index 1406faa9..e7fe2d0b 100644 --- a/frontend/src/components/account/QuotaDimensionRow.vue +++ b/frontend/src/components/account/QuotaDimensionRow.vue @@ -1,6 +1,7 @@ diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 77e437a8..68a68f29 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -2,6 +2,7 @@ import { ref, watch, computed } from 'vue' import { useI18n } from 'vue-i18n' import QuotaDimensionRow from './QuotaDimensionRow.vue' +import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account' const { t } = useI18n() @@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{ totalLimit: number | null dailyLimit: number | null weeklyLimit: number | null - dailyResetMode: 'rolling' | 'fixed' | null + dailyResetMode: QuotaResetMode | null dailyResetHour: number | null - weeklyResetMode: 'rolling' | 'fixed' | null + weeklyResetMode: QuotaResetMode | null weeklyResetDay: number | null weeklyResetHour: number | null resetTimezone: string | null quotaNotifyGlobalEnabled?: boolean quotaNotifyDailyEnabled?: boolean | null quotaNotifyDailyThreshold?: number | null - quotaNotifyDailyThresholdType?: string | null + quotaNotifyDailyThresholdType?: QuotaThresholdType | null quotaNotifyWeeklyEnabled?: boolean | null quotaNotifyWeeklyThreshold?: number | null - quotaNotifyWeeklyThresholdType?: string | null + quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null quotaNotifyTotalEnabled?: boolean | null quotaNotifyTotalThreshold?: number | null - quotaNotifyTotalThresholdType?: string | null + quotaNotifyTotalThresholdType?: QuotaThresholdType | null }>(), { quotaNotifyGlobalEnabled: false, quotaNotifyDailyEnabled: null, @@ -42,21 +43,21 @@ const emit = defineEmits<{ 'update:totalLimit': [value: number | null] 'update:dailyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null] - 'update:dailyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:dailyResetMode': [value: QuotaResetMode | null] 'update:dailyResetHour': [value: number | null] - 'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:weeklyResetMode': [value: QuotaResetMode | null] 'update:weeklyResetDay': [value: number | null] 'update:weeklyResetHour': [value: number | null] 'update:resetTimezone': [value: string | null] 'update:quotaNotifyDailyEnabled': [value: boolean | null] 'update:quotaNotifyDailyThreshold': [value: number | null] - 'update:quotaNotifyDailyThresholdType': [value: string | null] + 'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null] 'update:quotaNotifyWeeklyEnabled': [value: boolean | null] 'update:quotaNotifyWeeklyThreshold': [value: number | null] - 'update:quotaNotifyWeeklyThresholdType': [value: string | null] + 'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null] 'update:quotaNotifyTotalEnabled': [value: boolean | null] 'update:quotaNotifyTotalThreshold': [value: number | null] - 'update:quotaNotifyTotalThresholdType': [value: string | null] + 'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null] }>() const enabled = computed(() => @@ -89,11 +90,6 @@ watch(localEnabled, (val) => { } }) -// Whether any fixed mode is active (to show timezone selector) -const hasFixedMode = computed(() => - props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed' -) - // Common timezone options const timezoneOptions = [ 'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata', @@ -102,18 +98,6 @@ const timezoneOptions = [ 'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland', ] -// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone. -function getTimezoneOffsetLabel(tz: string): string { - try { - const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' }) - const parts = dtf.formatToParts(new Date()) - const tzPart = parts.find(p => p.type === 'timeZoneName') - return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : '' - } catch { - return '' - } -} - // Hours for dropdown (0-23) const hourOptions = Array.from({ length: 24 }, (_, i) => i) @@ -197,6 +181,7 @@ const dailyFixedHint = computed(() => :hint-fixed="dailyFixedHint" :hour-options="hourOptions" :day-options="dayOptions" + :timezone-options="timezoneOptions" @update:limit="emit('update:dailyLimit', $event)" @update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)" @update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)" @@ -223,6 +208,7 @@ const dailyFixedHint = computed(() => :hint-fixed="weeklyFixedHint" :hour-options="hourOptions" :day-options="dayOptions" + :timezone-options="timezoneOptions" @update:limit="emit('update:weeklyLimit', $event)" @update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)" @update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)" @@ -233,14 +219,6 @@ const dailyFixedHint = computed(() => @update:reset-timezone="emit('update:resetTimezone', $event)" /> - -
- - -
- -import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account' +import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account' defineProps<{ enabled: boolean | null threshold: number | null - thresholdType: string | null // "fixed" (default) or "percentage" + thresholdType: QuotaThresholdType | null }>() const emit = defineEmits<{ 'update:enabled': [value: boolean | null] 'update:threshold': [value: number | null] - 'update:thresholdType': [value: string | null] + 'update:thresholdType': [value: QuotaThresholdType | null] }>() @@ -43,7 +43,7 @@ const emit = defineEmits<{ /> - {{ getGroupNameById(gid) }} + {{ getGroupNameById(gid) }}

@@ -481,7 +481,7 @@ :key="accountId" class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20" > - {{ getRuleAccountLabel(accountId) }} + {{ getRuleAccountLabel(accountId) }} @@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' -import { platformTextClass } from '@/utils/platformColors' +import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -720,26 +720,6 @@ let abortController: AbortController | null = null // ── Platform config ── const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity'] -function getPlatformTextColor(platform: string): string { - switch (platform) { - case 'anthropic': return 'text-orange-600 dark:text-orange-400' - case 'openai': return 'text-emerald-600 dark:text-emerald-400' - case 'gemini': return 'text-blue-600 dark:text-blue-400' - case 'antigravity': return 'text-purple-600 dark:text-purple-400' - default: return 'text-gray-600 dark:text-gray-400' - } -} - -function getRateBadgeClass(platform: string): string { - switch (platform) { - case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' - case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' - case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' - case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400' - default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400' - } -} - // ── Helpers ── function formatDate(value: string): string { if (!value) return '-' From 9c09bd19b479c562098a58a821f69c42d7214204 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 00:42:40 +0800 Subject: [PATCH 078/122] fix: websearch features_config cleanup and pricing rules validation - Fix web_search_emulation toggle: explicitly write false for disabled platforms instead of leaving stale true from cloned features_config - Extract validatePricingEntries from validateChannelConfig for reuse - Validate account_stats_pricing_rules[].pricing in both Create and Update paths (negative prices, bad intervals, missing per_request price) --- backend/internal/service/channel_service.go | 22 ++++++++++++++++++--- frontend/src/views/admin/ChannelsView.vue | 8 ++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index d0698f0f..aa5e2ceb 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -566,15 +566,21 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { // validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 // Create 和 Update 共用此函数,避免重复。 func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { + if err := validatePricingEntries(pricing); err != nil { + return err + } + return validateNoConflictingMappings(mapping) +} + +// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验), +// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。 +func validatePricingEntries(pricing []ChannelModelPricing) error { if err := validateNoConflictingModels(pricing); err != nil { return err } if err := validatePricingIntervals(pricing); err != nil { return err } - if err := validateNoConflictingMappings(mapping); err != nil { - return err - } return validatePricingBillingMode(pricing) } @@ -684,6 +690,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } + for i, rule := range channel.AccountStatsPricingRules { + if err := validatePricingEntries(rule.Pricing); err != nil { + return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err) + } + } if err := s.repo.Create(ctx, channel); err != nil { return nil, fmt.Errorf("create channel: %w", err) @@ -712,6 +723,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } + for i, rule := range channel.AccountStatsPricingRules { + if err := validatePricingEntries(rule.Pricing); err != nil { + return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err) + } + } oldGroupIDs := s.getOldGroupIDs(ctx, id) diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 52d57d74..0b37a20d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -1032,15 +1032,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ } // Collect web_search_emulation (only anthropic platform supports it) + // Always write the key so that disabling in the UI correctly sets platform to false, + // rather than leaving a stale true value from the cloned features_config. const wsEmulation: Record = {} for (const section of form.platforms) { if (!section.enabled) continue - if (section.web_search_emulation && section.platform === 'anthropic') { - wsEmulation[section.platform] = true + if (section.platform === 'anthropic') { + wsEmulation[section.platform] = !!section.web_search_emulation } } if (Object.keys(wsEmulation).length > 0) { featuresConfig.web_search_emulation = wsEmulation + } else { + delete featuresConfig.web_search_emulation } return { group_ids, model_pricing, model_mapping, features_config: featuresConfig } From 0a4ece5f5be30db727bb9694588bfd0d0d99a26c Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 01:10:46 +0800 Subject: [PATCH 079/122] =?UTF-8?q?fix:=20audit=20round-3=20=E2=80=94=20pr?= =?UTF-8?q?oxy=20safety,=20intervals=20persistence,=20SMTP=20timeout,=20so?= =?UTF-8?q?rt=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Skip websearch provider when ProxyID is set but proxy not found (prevent silent direct connection bypass) - Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap keeps weights aligned - Allow empty platform in account_stats_pricing_rules (wildcard matching), only force anthropic default for main model_pricing - Add channel_account_stats_pricing_intervals table and repo layer support for interval-based pricing in account stats rules - calculateTokenStatsCost now uses interval pricing when available - Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO) to prevent goroutine leak on SMTP hang - Fix gofmt formatting issues - Web Search label: black text with red warning hint --- .../internal/handler/admin/channel_handler.go | 14 +++- backend/internal/pkg/websearch/manager.go | 17 +++-- .../channel_repo_account_stats_pricing.go | 74 +++++++++++++++++++ backend/internal/repository/email_cache.go | 10 +-- backend/internal/server/http.go | 6 ++ .../internal/service/account_stats_pricing.go | 31 ++++++-- .../service/balance_notify_service.go | 1 - backend/internal/service/email_service.go | 49 +++++++++++- .../internal/service/notify_email_entry.go | 1 - ...06_add_account_stats_pricing_intervals.sql | 19 +++++ frontend/src/views/admin/ChannelsView.vue | 4 +- 11 files changed, 199 insertions(+), 27 deletions(-) create mode 100644 backend/migrations/106_add_account_stats_pricing_intervals.sql diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 1a328551..88d27c47 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -249,9 +249,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe billingMode = service.BillingModeToken } platform := r.Platform - if platform == "" { - platform = service.PlatformAnthropic - } intervals := make([]service.PricingInterval, 0, len(r.Intervals)) for _, iv := range r.Intervals { intervals = append(intervals, service.PricingInterval{ @@ -349,6 +346,12 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) + // Main model_pricing requires a platform; default to anthropic for backward compatibility. + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { @@ -415,6 +418,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } input.ModelPricing = &pricing } if req.AccountStatsPricingRules != nil { diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 27592459..61faa616 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -200,13 +200,20 @@ func sortByStableRandomWeight(items []weighted) { if len(items) <= 1 { return } - factors := make([]float64, len(items)) - for i, item := range items { - factors[i] = float64(item.weight) * (0.5 + rand.Float64()) + type entry struct { + item weighted + factor float64 } - sort.Slice(items, func(i, j int) bool { - return factors[i] > factors[j] + entries := make([]entry, len(items)) + for i, item := range items { + entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())} + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].factor > entries[j].factor }) + for i, e := range entries { + items[i] = e.item + } } func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig { diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go index ef8f5177..9e00fed8 100644 --- a/backend/internal/repository/channel_repo_account_stats_pricing.go +++ b/backend/internal/repository/channel_repo_account_stats_pricing.go @@ -96,6 +96,27 @@ func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Contex if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate account stats model pricing: %w", err) } + + // Load intervals for all pricing entries. + var allPricingIDs []int64 + for _, pricings := range pricingMap { + for _, p := range pricings { + allPricingIDs = append(allPricingIDs, p.ID) + } + } + if len(allPricingIDs) > 0 { + intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for ruleID, pricings := range pricingMap { + for i := range pricings { + pricings[i].Intervals = intervalsMap[pricings[i].ID] + } + pricingMap[ruleID] = pricings + } + } + return pricingMap, nil } @@ -166,5 +187,58 @@ func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID in if err != nil { return fmt.Errorf("insert account stats model pricing: %w", err) } + // Persist intervals (mirrors channel_pricing_intervals logic). + for i := range pricing.Intervals { + iv := &pricing.Intervals[i] + iv.PricingID = pricing.ID + if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil { + return err + } + } return nil } + +// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry. +func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error { + return tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries. +func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + if len(pricingIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + result := make(map[int64][]service.PricingInterval) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing interval: %w", err) + } + result[iv.PricingID] = append(result[iv.PricingID], iv) + } + return result, rows.Err() +} diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 0eb6bef1..96a23a8e 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -12,11 +12,11 @@ import ( ) const ( - verifyCodeKeyPrefix = "verify_code:" - notifyVerifyKeyPrefix = "notify_verify:" - passwordResetKeyPrefix = "password_reset:" - passwordResetSentAtKeyPrefix = "password_reset_sent:" - notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" + verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 5165b059..d203bab2 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -4,6 +4,7 @@ package server import ( "context" "log" + "log/slog" "net/http" "time" @@ -82,6 +83,11 @@ func ProvideRouter( pc.ProxyID = *p.ProxyID if u, ok := proxyURLs[*p.ProxyID]; ok { pc.ProxyURL = u + } else { + // Proxy configured but not found — skip this provider to prevent direct connection. + slog.Warn("websearch: proxy not found for provider, skipping", + "provider", p.Type, "proxy_id", *p.ProxyID) + continue } } configs = append(configs, pc) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 8251dede..61c318d9 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -195,18 +195,33 @@ func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int } // calculateTokenStatsCost Token 计费。 +// If the pricing has intervals, find the matching interval by total token count +// and use its prices instead of the flat pricing fields. func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 { - deref := func(p *float64) float64 { - if p == nil { + p := pricing + if len(pricing.Intervals) > 0 { + totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens + if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil { + p = &ChannelModelPricing{ + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + } + } + } + deref := func(ptr *float64) float64 { + if ptr == nil { return 0 } - return *p + return *ptr } - cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) + - float64(tokens.OutputTokens)*deref(pricing.OutputPrice) + - float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + - float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + - float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) + cost := float64(tokens.InputTokens)*deref(p.InputPrice) + + float64(tokens.OutputTokens)*deref(p.OutputPrice) + + float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) + + float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) + + float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice) if cost <= 0 { return nil } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 5e9afcc8..5b7e413a 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -477,4 +477,3 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, account } return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay) } - diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index a94e0dde..425887cd 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "math/big" + "net" "net/smtp" "net/url" "strconv" @@ -152,6 +153,9 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) return s.SendEmailWithConfig(config, to, subject, body) } +const smtpDialTimeout = 10 * time.Second +const smtpIOTimeout = 20 * time.Second + // SendEmailWithConfig 使用指定配置发送邮件 func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { // Sanitize all SMTP header fields to prevent header injection (CR/LF removal). @@ -173,7 +177,46 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host) } - return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg)) + return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host) +} + +// sendMailPlain sends mail without TLS using a dialer with timeout. +func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error { + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := dialer.Dial("tcp", addr) + if err != nil { + return fmt.Errorf("smtp dial: %w", err) + } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) + defer func() { _ = conn.Close() }() + + client, err := smtp.NewClient(conn, host) + if err != nil { + return fmt.Errorf("new smtp client: %w", err) + } + defer func() { _ = client.Close() }() + + if err = client.Auth(auth); err != nil { + return fmt.Errorf("smtp auth: %w", err) + } + if err = client.Mail(from); err != nil { + return fmt.Errorf("smtp mail: %w", err) + } + if err = client.Rcpt(to); err != nil { + return fmt.Errorf("smtp rcpt: %w", err) + } + w, err := client.Data() + if err != nil { + return fmt.Errorf("smtp data: %w", err) + } + if _, err = w.Write(msg); err != nil { + return fmt.Errorf("write msg: %w", err) + } + if err = w.Close(); err != nil { + return fmt.Errorf("close writer: %w", err) + } + _ = client.Quit() + return nil } // sendMailTLS 使用TLS发送邮件 @@ -184,10 +227,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, MinVersion: tls.VersionTLS12, } - conn, err := tls.Dial("tcp", addr, tlsConfig) + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) if err != nil { return fmt.Errorf("tls dial: %w", err) } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) defer func() { _ = conn.Close() }() client, err := smtp.NewClient(conn, host) diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go index d181200b..625185b2 100644 --- a/backend/internal/service/notify_email_entry.go +++ b/backend/internal/service/notify_email_entry.go @@ -79,4 +79,3 @@ func MarshalNotifyEmails(entries []NotifyEmailEntry) string { } return string(data) } - diff --git a/backend/migrations/106_add_account_stats_pricing_intervals.sql b/backend/migrations/106_add_account_stats_pricing_intervals.sql new file mode 100644 index 00000000..5ae10655 --- /dev/null +++ b/backend/migrations/106_add_account_stats_pricing_intervals.sql @@ -0,0 +1,19 @@ +-- Add intervals table for account stats pricing rules (mirrors channel_pricing_intervals). +CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_intervals ( + id BIGSERIAL PRIMARY KEY, + pricing_id BIGINT NOT NULL REFERENCES channel_account_stats_model_pricing(id) ON DELETE CASCADE, + min_tokens INT NOT NULL DEFAULT 0, + max_tokens INT, + tier_label VARCHAR(50), + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + per_request_price NUMERIC(20,12), + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_account_stats_pricing_intervals_pricing_id + ON channel_account_stats_pricing_intervals (pricing_id); diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 0b37a20d..e4452b98 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -328,10 +328,10 @@

-
From b402c367d331c24b1507bdbc0596c06c933b1d55 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 01:38:42 +0800 Subject: [PATCH 080/122] fix: add opportunistic STARTTLS to sendMailPlain for 587 port compatibility smtp.SendMail automatically upgrades to STARTTLS when the server supports it. Our replacement sendMailPlain skipped this, causing credentials to be sent in plaintext on port 587. Add STARTTLS negotiation before Auth to restore the original security behavior. --- backend/internal/service/email_service.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 425887cd..9cfd3bbd 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -196,6 +196,14 @@ func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to strin } defer func() { _ = client.Close() }() + // Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it. + // This mirrors the behavior of smtp.SendMail which we replaced for timeout support. + if ok, _ := client.Extension("STARTTLS"); ok { + if err = client.StartTLS(&tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}); err != nil { + return fmt.Errorf("starttls: %w", err) + } + } + if err = client.Auth(auth); err != nil { return fmt.Errorf("smtp auth: %w", err) } From 9e0d12d3b03e6b1cc8cf96a0c59f6b0a288fab85 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 07:22:22 +0800 Subject: [PATCH 081/122] fix: show websearch API key visibility/copy buttons for saved providers The buttons were hidden because v-if only checked provider.api_key, which is always empty for saved providers (backend sanitizes it). Now also checks api_key_configured. Copy button is disabled when no actual key is available (only configured placeholder shown). --- backend/cmd/server/VERSION | 2 +- .../internal/handler/admin/setting_handler.go | 2 +- backend/internal/service/websearch_config.go | 22 ++++++++++++++++++ frontend/src/views/admin/SettingsView.vue | 23 ++++++++++++------- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 68dda295..5657b5e3 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.51 +0.1.112.3 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 2324cc70..2a87e95c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1939,7 +1939,7 @@ func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), cfg)) + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg)) } // UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置 diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index 5658cec3..239e882a 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -277,6 +277,28 @@ func TestWebSearch(ctx context.Context, query string) (*WebSearchTestResult, err }, nil } +// PopulateWebSearchUsage returns a copy with quota usage populated from Redis (api_key kept as-is). +func PopulateWebSearchUsage(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { + if cfg == nil { + return nil + } + out := *cfg + out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers)) + + mgr := getWebSearchManager() + + for i, p := range cfg.Providers { + out.Providers[i] = p + out.Providers[i].APIKeyConfigured = p.APIKey != "" + + if mgr != nil { + used, _ := mgr.GetUsage(ctx, p.Type) + out.Providers[i].QuotaUsed = used + } + } + return &out +} + // SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated. func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { if cfg == nil { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 3ef1c0ba..12f67187 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1775,8 +1775,8 @@ @click.stop /> - - {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} + + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} @@ -1797,10 +1797,10 @@ v-model="provider.api_key" :type="apiKeyVisible[pIdx] ? 'text' : 'password'" class="input w-full text-sm" - :class="provider.api_key ? 'pr-16' : ''" + :class="(provider.api_key || provider.api_key_configured) ? 'pr-16' : ''" :placeholder="provider.api_key_configured ? '••••••••' : t('admin.settings.webSearchEmulation.apiKeyPlaceholder')" /> -
+
-
+
{{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
+
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} +
+ {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }}
@@ -3164,9 +3167,13 @@ async function loadWebSearchConfig() { async function saveWebSearchConfig(): Promise { try { + const providers = webSearchConfig.providers.map((p: WebSearchProviderConfig) => ({ + ...p, + quota_limit: typeof p.quota_limit === 'number' && p.quota_limit > 0 ? p.quota_limit : 0, + })) await adminAPI.settings.updateWebSearchEmulationConfig({ enabled: webSearchConfig.enabled, - providers: webSearchConfig.providers as WebSearchProviderConfig[], + providers, }) return true } catch (err: unknown) { From 1e6912ea2e12beb052d1e0523aa321be619de57f Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 07:43:08 +0800 Subject: [PATCH 082/122] fix: gofmt formatting across all Go source files --- .../internal/handler/admin/setting_handler.go | 10 ++-- backend/internal/handler/dto/settings.go | 10 ++-- backend/internal/handler/dto/types.go | 8 +-- .../internal/payment/load_balancer_test.go | 2 +- .../internal/payment/provider/alipay_test.go | 6 +- backend/internal/service/account.go | 27 ++++++--- .../service/balance_notify_email_body_test.go | 18 +++--- backend/internal/service/domain_constants.go | 4 +- .../payment_config_plans_validation_test.go | 6 +- .../service/payment_config_providers.go | 59 +++++++++++++++---- .../service/payment_config_providers_test.go | 2 +- .../service/payment_config_service.go | 36 +++++------ .../service/setting_service_public_test.go | 2 +- .../service/setting_service_update_test.go | 4 +- backend/internal/service/settings_view.go | 12 ++-- backend/internal/service/usage_billing.go | 13 ++++ backend/internal/service/user.go | 2 +- backend/internal/service/user_service.go | 6 +- 18 files changed, 143 insertions(+), 84 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 2a87e95c..b50cad96 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -310,11 +310,11 @@ type UpdateSettingsRequest struct { EnableCCHSigning *bool `json:"enable_cch_signing"` // Balance low notification - BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index d218490a..ef285a44 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -150,11 +150,11 @@ type SystemSettings struct { PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` // Balance low notification - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index afb782b0..1aab1dbb 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -19,11 +19,11 @@ type User struct { UpdatedAt time.Time `json:"updated_at"` // 余额不足通知 - BalanceNotifyEnabled bool `json:"balance_notify_enabled"` - BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + BalanceNotifyEnabled bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` - TotalRecharged float64 `json:"total_recharged"` + TotalRecharged float64 `json:"total_recharged"` APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 568b56a3..04b3c25b 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) { wantIDs: nil, }, { - name: "empty candidates returns empty", + name: "empty candidates returns empty", candidates: nil, paymentType: "alipay", orderAmount: 10, diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 1b9d66ba..7b0ce0d8 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) { errSubstr: "privateKey", }, { - name: "nil config map returns error for appId", - config: map[string]string{}, - wantErr: true, + name: "nil config map returns error for appId", + config: map[string]string{}, + wantErr: true, errSubstr: "appId", }, } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index cacfb240..52db3073 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1533,39 +1533,48 @@ func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64 } func (a *Account) GetQuotaNotifyDailyEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimDaily); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimDaily) + return e } func (a *Account) GetQuotaNotifyDailyThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimDaily); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimDaily) + return t } func (a *Account) GetQuotaNotifyDailyThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimDaily); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimDaily) + return tt } func (a *Account) GetQuotaNotifyWeeklyEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return e } func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return t } func (a *Account) GetQuotaNotifyWeeklyThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly) + return tt } func (a *Account) GetQuotaNotifyTotalEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimTotal); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimTotal) + return e } func (a *Account) GetQuotaNotifyTotalThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimTotal); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimTotal) + return t } func (a *Account) GetQuotaNotifyTotalThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimTotal); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimTotal) + return tt } // nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go index 9baf164e..aee5a5bc 100644 --- a/backend/internal/service/balance_notify_email_body_test.go +++ b/backend/internal/service/balance_notify_email_body_test.go @@ -65,15 +65,15 @@ func TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton(t *testing.T) { func TestBuildQuotaAlertEmailBody_AllFieldsPresent(t *testing.T) { s := &BalanceNotifyService{} body := s.buildQuotaAlertEmailBody( - 42, // accountID - "acc-foo", // accountName - "anthropic", // platform - "日限额 / Daily", // dimLabel - 750.50, // used - 1000.0, // limit - 249.50, // remaining - "$249.50", // thresholdDisplay - "MySite", // siteName + 42, // accountID + "acc-foo", // accountName + "anthropic", // platform + "日限额 / Daily", // dimLabel + 750.50, // used + 1000.0, // limit + 249.50, // remaining + "$249.50", // thresholdDisplay + "MySite", // siteName ) require.Contains(t, body, "MySite") diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 896ba59f..bdced29a 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -251,8 +251,8 @@ const ( SettingKeyEnableCCHSigning = "enable_cch_signing" // Balance Low Notification - SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 - SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) + SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 + SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL // Account Quota Notification diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index efdbdb10..bcbe901f 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -131,9 +131,9 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { // --- validatePlanPatch: other fields --- -func ptrStr(s string) *string { return &s } -func ptrInt(i int) *int { return &i } -func ptrInt64(i int64) *int64 { return &i } +func ptrStr(s string) *string { return &s } +func ptrInt(i int) *int { return &i } +func ptrInt64(i int64) *int64 { return &i } func ptrFloat(f float64) *float64 { return &f } func TestValidatePlanPatch_EmptyName(t *testing.T) { diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 10181914..0c71ab29 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. @@ -46,8 +47,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, - PaymentMode: inst.PaymentMode, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) if err != nil { @@ -110,10 +111,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } + allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -221,6 +224,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) + // Cascade: turning off refund_enabled also disables allow_user_refund + if !*req.RefundEnabled { + u.SetAllowUserRefund(false) + } + } + if req.AllowUserRefund != nil { + // Only allow enabling when refund_enabled is true + if *req.AllowUserRefund { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil && inst.RefundEnabled { + u.SetAllowUserRefund(true) + } + } else { + u.SetAllowUserRefund(false) + } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -228,6 +246,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in return u.Save(ctx) } +// GetUserRefundEligibleInstanceIDs returns provider instance IDs that allow user refund. +func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.AllowUserRefundEQ(true), + paymentproviderinstance.RefundEnabledEQ(true), + ).Select(paymentproviderinstance.FieldID).All(ctx) + if err != nil { + return nil, err + } + ids := make([]string, 0, len(instances)) + for _, inst := range instances { + ids = append(ids, strconv.FormatInt(int64(inst.ID), 10)) + } + return ids, nil +} + func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) { inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) if err != nil { diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index e71eb9f5..2aaa874f 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string + field string wantSen bool }{ // Sensitive fields (contain key/secret/private/password/pkey patterns) diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..cce31f4d 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,28 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` + AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 6dfa627c..5cf1e860 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { repo := &settingPublicRepoStub{ values: map[string]string{ - SettingKeyTableDefaultPageSize: "50", + SettingKeyTableDefaultPageSize: "50", SettingKeyTablePageSizeOptions: "[20,50,100]", }, } diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index 28c7ad02..e62218b4 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { svc := NewSettingService(repo, &config.Config{}) err := svc.UpdateSettings(context.Background(), &SystemSettings{ - TableDefaultPageSize: 50, + TableDefaultPageSize: 50, TablePageSizeOptions: []int{20, 50, 100}, }) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions]) err = svc.UpdateSettings(context.Background(), &SystemSettings{ - TableDefaultPageSize: 1000, + TableDefaultPageSize: 1000, TablePageSizeOptions: []int{20, 100}, }) require.NoError(t, err) diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 57f3746a..ec20fe0a 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -108,8 +108,8 @@ type SystemSettings struct { EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) // Balance low notification - BalanceLowNotifyEnabled bool - BalanceLowNotifyThreshold float64 + BalanceLowNotifyEnabled bool + BalanceLowNotifyThreshold float64 BalanceLowNotifyRechargeURL string // Account quota notification @@ -155,10 +155,10 @@ type PublicSettings struct { OIDCOAuthProviderName string Version string - BalanceLowNotifyEnabled bool - AccountQuotaNotifyEnabled bool - BalanceLowNotifyThreshold float64 - BalanceLowNotifyRechargeURL string + BalanceLowNotifyEnabled bool + AccountQuotaNotifyEnabled bool + BalanceLowNotifyThreshold float64 + BalanceLowNotifyRechargeURL string } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go index 73b05743..30495624 100644 --- a/backend/internal/service/usage_billing.go +++ b/backend/internal/service/usage_billing.go @@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 { return *v } +// AccountQuotaState holds the post-increment quota state returned by the DB transaction. +// All values are post-update (i.e., already include the increment). +type AccountQuotaState struct { + TotalUsed float64 + TotalLimit float64 + DailyUsed float64 + DailyLimit float64 + WeeklyUsed float64 + WeeklyLimit float64 +} + type UsageBillingApplyResult struct { Applied bool APIKeyQuotaExhausted bool + NewBalance *float64 // post-deduction balance (nil = no balance deduction) + QuotaState *AccountQuotaState // post-increment quota state (nil = no quota increment) } type UsageBillingRepository interface { diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index d3d8c954..59f8aa6b 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -32,7 +32,7 @@ type User struct { // 余额不足通知 BalanceNotifyEnabled bool - BalanceNotifyThresholdType string // "fixed" (default) | "percentage" + BalanceNotifyThresholdType string // "fixed" (default) | "percentage" BalanceNotifyThreshold *float64 BalanceNotifyExtraEmails []NotifyEmailEntry TotalRecharged float64 diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7724a5a..3490e804 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -13,9 +13,9 @@ import ( ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") ) From 7c7292935e8cefb4f2a2ebbf732bd923cb55c8e7 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 08:03:27 +0800 Subject: [PATCH 083/122] feat: websearch quota enhancements and balance notify hint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - QuotaLimit changed to *int64 (null=unlimited, >0=limited) - Add reset-usage endpoint (POST /admin/settings/web-search-emulation/reset-usage) - Show quota usage in header always (collapsed and expanded) - Add reset quota button in expanded provider view - Quota input: empty=unlimited with ∞ placeholder, must be >0 if set - Add email verification hint on balance notify card --- backend/cmd/server/VERSION | 2 +- .../internal/handler/admin/setting_handler.go | 23 +++++++++- backend/internal/pkg/websearch/manager.go | 9 ++++ backend/internal/server/http.go | 9 +++- backend/internal/server/routes/admin.go | 1 + backend/internal/service/websearch_config.go | 15 +++++-- .../internal/service/websearch_config_test.go | 18 ++++---- frontend/src/api/admin/settings.ts | 11 ++++- .../user/profile/ProfileBalanceNotifyCard.vue | 1 + frontend/src/i18n/locales/en.ts | 9 +++- frontend/src/i18n/locales/zh.ts | 9 +++- frontend/src/views/admin/SettingsView.vue | 42 +++++++++++++++---- 12 files changed, 121 insertions(+), 28 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 5657b5e3..630554d9 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.112.3 +0.1.112.4 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index b50cad96..9b49150c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1962,7 +1962,28 @@ func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), updated)) + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated)) +} + +// ResetWebSearchUsage 重置指定 provider 的配额用量 +// POST /api/v1/admin/settings/web-search-emulation/reset-usage +func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) { + var req struct { + ProviderType string `json:"provider_type"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if req.ProviderType == "" { + response.BadRequest(c, "provider_type is required") + return + } + if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) } // TestWebSearchEmulation 测试 Web Search 搜索 diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 61faa616..307aa1e9 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -447,6 +447,15 @@ func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 { return result } +// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0. +func (m *Manager) ResetUsage(ctx context.Context, providerType string) error { + if m.redis == nil { + return nil + } + key := quotaRedisKey(providerType) + return m.redis.Del(ctx, key).Err() +} + // --- Provider factory --- func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider { diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d203bab2..023e40bb 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -73,7 +73,7 @@ func ProvideRouter( pc := websearch.ProviderConfig{ Type: p.Type, APIKey: p.APIKey, - QuotaLimit: p.QuotaLimit, + QuotaLimit: derefInt64(p.QuotaLimit), ExpiresAt: p.ExpiresAt, } if p.SubscribedAt != nil { @@ -141,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 } } + +func derefInt64(p *int64) int64 { + if p == nil { + return 0 + } + return *p +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0a7b7a8b..9af0fd8e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -411,6 +411,7 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig) adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig) adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation) + adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage) } } diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index 239e882a..f528a35b 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -24,7 +24,7 @@ type WebSearchProviderConfig struct { Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses APIKeyConfigured bool `json:"api_key_configured"` // read-only mask - QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited + QuotaLimit *int64 `json:"quota_limit"` // nil = unlimited, >0 = limited SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current usage from Redis ProxyID *int64 `json:"proxy_id"` // optional proxy association @@ -52,8 +52,8 @@ func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error { if !validProviderTypes[p.Type] { return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type) } - if p.QuotaLimit < 0 { - return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i) + if p.QuotaLimit != nil && *p.QuotaLimit < 0 { + return fmt.Errorf("provider[%d]: quota_limit must be > 0 or null", i) } if seen[p.Type] { return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type) @@ -299,6 +299,15 @@ func PopulateWebSearchUsage(ctx context.Context, cfg *WebSearchEmulationConfig) return &out } +// ResetWebSearchUsage deletes the Redis quota key for the given provider type. +func ResetWebSearchUsage(ctx context.Context, providerType string) error { + mgr := getWebSearchManager() + if mgr == nil { + return fmt.Errorf("web search manager not initialized") + } + return mgr.ResetUsage(ctx, providerType) +} + // SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated. func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { if cfg == nil { diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 4aea98b7..8cd50d0d 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -17,8 +17,8 @@ func TestValidateWebSearchConfig_Valid(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", QuotaLimit: 1000}, - {Type: "tavily", QuotaLimit: 500}, + {Type: "brave", QuotaLimit: int64Ptr(1000)}, + {Type: "tavily", QuotaLimit: int64Ptr(500)}, }, } require.NoError(t, validateWebSearchConfig(cfg)) @@ -42,9 +42,9 @@ func TestValidateWebSearchConfig_InvalidType(t *testing.T) { func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}}, + Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: int64Ptr(-1)}}, } - require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0") + require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be > 0 or null") } func TestValidateWebSearchConfig_DuplicateType(t *testing.T) { @@ -57,9 +57,9 @@ func TestValidateWebSearchConfig_DuplicateType(t *testing.T) { require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type") } -func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) { +func TestValidateWebSearchConfig_NilQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}}, + Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: nil}}, } require.NoError(t, validateWebSearchConfig(cfg)) } @@ -92,7 +92,7 @@ func TestParseWebSearchConfigJSON_BackwardCompatibility(t *testing.T) { cfg := parseWebSearchConfigJSON(raw) require.True(t, cfg.Enabled) require.Len(t, cfg.Providers, 1) - require.Equal(t, int64(1000), cfg.Providers[0].QuotaLimit) + require.Equal(t, int64(1000), *cfg.Providers[0].QuotaLimit) } // --- SanitizeWebSearchConfig --- @@ -126,12 +126,12 @@ func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", APIKey: "secret", QuotaLimit: 1000}, + {Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(1000)}, }, } out := SanitizeWebSearchConfig(context.Background(), cfg) require.True(t, out.Enabled) - require.Equal(t, int64(1000), out.Providers[0].QuotaLimit) + require.Equal(t, int64(1000), *out.Providers[0].QuotaLimit) } func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 4b5eb242..aa1d0f82 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -502,7 +502,7 @@ export interface WebSearchProviderConfig { type: 'brave' | 'tavily' api_key: string api_key_configured: boolean - quota_limit: number + quota_limit: number | null subscribed_at: number | null quota_used?: number proxy_id: number | null @@ -547,6 +547,12 @@ export async function testWebSearchEmulation( return data } +export async function resetWebSearchUsage( + payload: { provider_type: string } +): Promise { + await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +} + export const settingsAPI = { getSettings, updateSettings, @@ -565,7 +571,8 @@ export const settingsAPI = { updateBetaPolicySettings, getWebSearchEmulationConfig, updateWebSearchEmulationConfig, - testWebSearchEmulation + testWebSearchEmulation, + resetWebSearchUsage } export default settingsAPI diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue index 3a84fd6b..c4d04153 100644 --- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue +++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue @@ -48,6 +48,7 @@
+

{{ t('profile.balanceNotify.extraEmailsHint') }}

diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c8acf6c0..9baddc43 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -914,6 +914,7 @@ export default { thresholdPlaceholder: 'Enter amount', systemDefault: 'System Default', extraEmails: 'Notification Emails', + extraEmailsHint: 'You must add and verify an email address to receive low balance alerts', primaryEmail: 'Primary', noExtraEmails: 'No extra notification emails', enterEmail: 'Enter email address', @@ -4435,10 +4436,14 @@ export default { copyApiKey: 'Copy', copied: 'Copied', quotaLimit: 'Quota Limit', - quotaLimitHint: '0 = unlimited', + quotaLimitHint: 'Leave empty for unlimited; must be > 0 if set', + quotaLimitMustBePositive: 'Quota limit must be greater than 0', subscribedAt: 'Subscribed At', - subscribedAtHint: 'Quota resets monthly from this date', + subscribedAtHint: 'Quota resets monthly from this date; leave empty to disable auto-reset', quotaUsage: 'Usage', + resetUsage: 'Reset', + resetUsageConfirm: 'Reset usage counter for this provider?', + resetUsageSuccess: 'Usage counter reset', proxy: 'Proxy', removeProvider: 'Remove', noProviders: 'No search providers configured', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 499ed9cb..af8da265 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -918,6 +918,7 @@ export default { thresholdPlaceholder: '输入金额', systemDefault: '系统默认值', extraEmails: '通知邮箱', + extraEmailsHint: '必须添加并验证邮箱后,余额不足时才能收到提醒邮件', primaryEmail: '主邮箱', noExtraEmails: '暂无额外通知邮箱', enterEmail: '输入邮箱地址', @@ -4597,10 +4598,14 @@ export default { copyApiKey: '复制', copied: '已复制', quotaLimit: '配额上限', - quotaLimitHint: '0 表示无限制', + quotaLimitHint: '留空表示无限制;填写时必须大于 0', + quotaLimitMustBePositive: '配额上限必须大于 0', subscribedAt: '订阅时间', - subscribedAtHint: '配额从此日期起每月自动重置', + subscribedAtHint: '配额从此日期起每月自动重置;留空则不自动重置', quotaUsage: '用量', + resetUsage: '重置', + resetUsageConfirm: '确定要重置此服务商的用量计数吗?', + resetUsageSuccess: '用量已重置', proxy: '代理', removeProvider: '删除', noProviders: '未配置搜索服务商', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 12f67187..c57d2033 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1774,9 +1774,9 @@ class="w-36" @click.stop /> - - - {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} + + + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} @@ -1835,7 +1835,7 @@
- +

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

@@ -1853,7 +1853,7 @@
{{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
+
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} +
@@ -3118,6 +3126,19 @@ function quotaPercentage(provider: WebSearchProviderConfig): number { return ((provider.quota_used ?? 0) / provider.quota_limit) * 100 } +async function resetWebSearchUsage(idx: number) { + const provider = webSearchConfig.providers[idx] + if (!provider) return + if (!confirm(t('admin.settings.webSearchEmulation.resetUsageConfirm'))) return + try { + await adminAPI.settings.resetWebSearchUsage({ provider_type: provider.type }) + provider.quota_used = 0 + appStore.showSuccess(t('admin.settings.webSearchEmulation.resetUsageSuccess')) + } catch (err: unknown) { + appStore.showError(extractApiErrorMessage(err, t('common.error'))) + } +} + async function copyApiKey(idx: number) { const key = webSearchConfig.providers[idx]?.api_key if (!key) { @@ -3167,9 +3188,16 @@ async function loadWebSearchConfig() { async function saveWebSearchConfig(): Promise { try { + for (const p of webSearchConfig.providers) { + const raw = p.quota_limit + if (raw != null && Number(raw) !== 0 && Number(raw) < 1) { + appStore.showError(t('admin.settings.webSearchEmulation.quotaLimitMustBePositive')) + return false + } + } const providers = webSearchConfig.providers.map((p: WebSearchProviderConfig) => ({ ...p, - quota_limit: typeof p.quota_limit === 'number' && p.quota_limit > 0 ? p.quota_limit : 0, + quota_limit: Number(p.quota_limit) > 0 ? Number(p.quota_limit) : null, })) await adminAPI.settings.updateWebSearchEmulationConfig({ enabled: webSearchConfig.enabled, From 9028d2085f12f5ea30fc94165af21b767e455643 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 08:42:28 +0800 Subject: [PATCH 084/122] test: add unit tests for billing, websearch, and notify systems Billing (25 tests): - CalculateCostUnified: nil resolver fallback, token/per_request/image modes - GetModelPricingWithChannel: nil/partial/full channel overrides - resolveAccountStatsCost: four-level priority chain integration tests WebSearch (18 tests): - PopulateWebSearchUsage: nil input, manager states, QuotaLimit nil/*int64 - ResetWebSearchUsage: nil manager error - Manager.ResetUsage: nil Redis - shouldEmulateWebSearch: full decision chain (8 scenarios) Notify (36 tests): - ParseNotifyEmails/MarshalNotifyEmails: old/new format, roundtrip - crossedDownward: boundary values, threshold semantics - checkQuotaDimCrossings: mixed dimensions, disabled/zero skip --- .../internal/pkg/websearch/manager_test.go | 8 + .../service/account_stats_pricing_test.go | 242 ++++++++++++++++ .../service/balance_notify_check_test.go | 224 +++++++++++++++ .../internal/service/billing_service_test.go | 120 ++++++++ .../service/billing_service_unified_test.go | 258 ++++++++++++++++++ .../gateway_websearch_emulation_test.go | 238 ++++++++++++++++ .../service/notify_email_entry_test.go | 156 +++++++++++ .../internal/service/websearch_config_test.go | 123 +++++++++ 8 files changed, 1369 insertions(+) create mode 100644 backend/internal/service/billing_service_unified_test.go create mode 100644 backend/internal/service/notify_email_entry_test.go diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index a4beef68..cbcf1b76 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -313,3 +313,11 @@ func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) { require.NoError(t, err) require.NotNil(t, c) } + +// --- ResetUsage --- + +func TestManager_ResetUsage_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + err := m.ResetUsage(context.Background(), "brave") + require.NoError(t, err) +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 23409d5e..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -3,7 +3,9 @@ package service import ( + "context" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -527,3 +529,243 @@ func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 require.InDelta(t, 0.95, *result, 1e-12) } + +// --------------------------------------------------------------------------- +// resolveAccountStatsCost — integration tests covering the 4-level priority chain +// --------------------------------------------------------------------------- + +func TestResolveAccountStatsCost_NilChannelService(t *testing.T) { + result := resolveAccountStatsCost( + context.Background(), + nil, // channelService is nil + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) { + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "", // empty upstream model + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) { + // Group 99 is NOT in the cache, so GetChannelForGroup returns nil + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 99, "claude-sonnet-4", // groupID 99 has no channel + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.01), + OutputPrice: testPtrFloat64(0.02), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService not needed when custom rule hits + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored because custom rule hits + ) + require.NotNil(t, result) + // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 0.75, // totalCost = 0.75 + ) + require.NotNil(t, result) + require.InDelta(t, 0.75, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + UsageTokens{}, 1, 0.0, // totalCost = 0 + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, // not enabled + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored + ) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + // BillingService with no pricing for the model + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "totally-unknown-model", + tokens, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService is nil + 1, 10, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) { + // Both custom rule and ApplyPricingToAccountStats are configured; + // custom rule should take precedence. + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.05), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins) + ) + require.NotNil(t, result) + // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost) + require.InDelta(t, 5.0, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// helpers for resolveAccountStatsCost tests +// --------------------------------------------------------------------------- + +// newTestChannelServiceForStats creates a ChannelService with a single channel +// mapped to the given groupID, suitable for resolveAccountStatsCost tests. +func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService { + t.Helper() + cache := newEmptyChannelCache() + cache.channelByGroupID[groupID] = channel + cache.groupPlatform[groupID] = platform + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go index 955f3129..7bb4cf9e 100644 --- a/backend/internal/service/balance_notify_check_test.go +++ b/backend/internal/service/balance_notify_check_test.go @@ -178,3 +178,227 @@ func TestGetSiteName_Configured(t *testing.T) { repo.data[SettingKeySiteName] = "My Site" require.Equal(t, "My Site", s.getSiteName(context.Background())) } + +// ---------- crossedDownward ---------- + +func TestCrossedDownward_CrossesBelow(t *testing.T) { + // oldBalance > threshold, newBalance < threshold → true + require.True(t, crossedDownward(100, 5, 10)) +} + +func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) { + // oldBalance > threshold, newBalance == threshold → false (not below) + require.False(t, crossedDownward(100, 10, 10)) +} + +func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) { + // oldBalance == threshold, newBalance < threshold → true + // (at-or-above → below counts as a crossing) + require.True(t, crossedDownward(10, 5, 10)) +} + +func TestCrossedDownward_AlreadyBelow(t *testing.T) { + // oldBalance < threshold → false (already below, no new crossing) + require.False(t, crossedDownward(5, 3, 10)) +} + +func TestCrossedDownward_BothAbove(t *testing.T) { + // oldBalance > threshold, newBalance > threshold → false (no crossing) + require.False(t, crossedDownward(100, 50, 10)) +} + +func TestCrossedDownward_ZeroThreshold(t *testing.T) { + // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives + // Typical case: positive balances should not fire when threshold is 0. + require.False(t, crossedDownward(10, 5, 0)) + require.False(t, crossedDownward(0, 0, 0)) +} + +func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) { + // Edge case: newBalance goes negative with threshold=0. + require.True(t, crossedDownward(5, -1, 0)) +} + +func TestCrossedDownward_NegativeValues(t *testing.T) { + // Both already negative, threshold is positive → no crossing (already below). + require.False(t, crossedDownward(-5, -10, 10)) +} + +func TestCrossedDownward_LargeDecrement(t *testing.T) { + // A single large deduction crosses the threshold. + require.True(t, crossedDownward(1000, 0.5, 100)) +} + +func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) { + // A tiny deduction stays above threshold. + require.False(t, crossedDownward(100, 99.99, 10)) +} + +// ---------- checkQuotaDimCrossings ---------- + +func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // Empty dims → no crossing, no panic. + s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite") + s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: false, // disabled + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Disabled dimension should be skipped even if crossing would occur. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 0, // zero threshold + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Zero threshold → skipped. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 800, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200 + // Negative resolved threshold → skipped. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 1200, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700 + // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing. + dims := []quotaDim{ + { + name: quotaDimWeekly, + enabled: true, + threshold: 30, + thresholdType: thresholdTypePercentage, + currentUsed: 500, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // limit=0 → resolvedThreshold returns 0 → skipped. + dims := []quotaDim{ + { + name: quotaDimTotal, + enabled: true, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 50, + limit: 0, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // dim1: no crossing (both below effective threshold) + // dim2: disabled (skipped) + // dim3: zero threshold (skipped) + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below + limit: 1000, + }, + { + name: quotaDimWeekly, + enabled: false, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 900, + limit: 1000, + }, + { + name: quotaDimTotal, + enabled: true, + threshold: 0, + thresholdType: thresholdTypeFixed, + currentUsed: 500, + limit: 1000, + }, + } + // None should trigger. No panic expected. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 6f6c41ce..2cf134e2 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -718,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing. require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) } + +// --------------------------------------------------------------------------- +// GetModelPricingWithChannel +// --------------------------------------------------------------------------- + +func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil) + require.NoError(t, err) + require.NotNil(t, pricing) + + // Should be identical to GetModelPricing + original, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(99e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // InputPrice overridden (both normal and priority) + require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12) + + // OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + OutputPrice: testPtrFloat64(88e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // OutputPrice overridden + require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12) + + // InputPrice unchanged (claude-sonnet-4 fallback = 3e-6) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(20e-6), + CacheWritePrice: testPtrFloat64(5e-6), + CacheReadPrice: testPtrFloat64(1e-6), + ImageOutputPrice: testPtrFloat64(50e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheWritePrice: testPtrFloat64(7e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h + require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheReadPrice: testPtrFloat64(2e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheReadPrice should set both normal and priority + require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(1e-6), + } + pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing) + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go new file mode 100644 index 00000000..694c3384 --- /dev/null +++ b/backend/internal/service/billing_service_unified_test.go @@ -0,0 +1,258 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// CalculateCostUnified +// --------------------------------------------------------------------------- + +func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: nil, // no resolver + } + cost, err := svc.CalculateCostUnified(input) + require.NoError(t, err) + + // Should match the old-path result exactly + expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil) + require.NoError(t, err) + require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) + // BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil) + require.Empty(t, cost.BillingMode) +} + +func TestCalculateCostUnified_TokenMode(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.5, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075 + expectedTotal := 1000*3e-6 + 500*15e-6 + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) + require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeToken), cost.BillingMode) +} + +func TestCalculateCostUnified_PerRequestMode(t *testing.T) { + // Set up a ChannelService with a per-request pricing channel + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 1, model: "claude-sonnet-4"}: { + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + }, + }, + channelByGroupID: map[int64]*Channel{ + 1: {ID: 1, Status: StatusActive}, + }, + groupPlatform: map[int64]string{1: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := newTestBillingService() + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(1) + + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + GroupID: &groupID, + Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50}, + RequestCount: 3, + RateMultiplier: 2.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 3 requests * $0.05 = $0.15 + require.InDelta(t, 0.15, cost.TotalCost, 1e-10) + // ActualCost = 0.15 * 2.0 = 0.30 + require.InDelta(t, 0.30, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +func TestCalculateCostUnified_ImageMode(t *testing.T) { + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 2, model: "gemini-image"}: { + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.10), + }, + }, + channelByGroupID: map[int64]*Channel{ + 2: {ID: 2, Status: StatusActive}, + }, + groupPlatform: map[int64]string{2: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{}, + } + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(2) + + input := CostInput{ + Ctx: context.Background(), + Model: "gemini-image", + GroupID: &groupID, + Tokens: UsageTokens{}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.10 = $0.20 + require.InDelta(t, 0.20, cost.TotalCost, 1e-10) + require.InDelta(t, 0.20, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeImage), cost.BillingMode) +} + +func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + costZero, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 0, // should default to 1.0 + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: -5.0, + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + require.Equal(t, "token", cost.BillingMode) +} + +func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + // Pre-resolve with per_request mode to verify it's used instead of re-resolving + preResolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + DefaultPerRequestPrice: 0.07, + } + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + Resolved: preResolved, + }) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.07 = $0.14 + require.InDelta(t, 0.14, cost.TotalCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// newTestChannelServiceWithCache creates a ChannelService with a pre-populated +// cache snapshot, bypassing the repository layer entirely. +func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService { + t.Helper() + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go index b606c748..de1f0014 100644 --- a/backend/internal/service/gateway_websearch_emulation_test.go +++ b/backend/internal/service/gateway_websearch_emulation_test.go @@ -1,8 +1,14 @@ +//go:build unit + package service import ( + "context" + "encoding/json" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -140,3 +146,235 @@ func TestBuildTextSummary_NoResults(t *testing.T) { summary := buildTextSummary("test", nil) require.Contains(t, summary, "No search results found for: test") } + +// --- shouldEmulateWebSearch --- + +// webSearchToolBody is a valid request body with exactly one web_search tool. +var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`) + +// nonWebSearchToolBody is a request body without web_search tool. +var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`) + +// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode. +func newAnthropicAPIKeyAccount(mode string) *Account { + return &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: mode}, + } +} + +// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled. +func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) { + webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{ + config: cfg, + expiresAt: time.Now().Add(10 * time.Minute).UnixNano(), + }) +} + +// clearGlobalWebSearchConfig resets the global cache to force re-read. +func clearGlobalWebSearchConfig() { + webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil)) +} + +// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config. +func newSettingServiceForWebSearchTest(enabled bool) *SettingService { + repo := newMockSettingRepo() + cfg := &WebSearchEmulationConfig{ + Enabled: enabled, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}}, + } + data, _ := json.Marshal(cfg) + repo.data[SettingKeyWebSearchEmulationConfig] = string(data) + return NewSettingService(repo, &config.Config{}) +} + +// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel. +func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService { + svc := &ChannelService{} + cache := &channelCache{ + channelByGroupID: map[int64]*Channel{groupID: ch}, + byID: map[int64]*Channel{ch.ID: ch}, + groupPlatform: map[int64]string{}, + loadedAt: time.Now(), + } + svc.cache.Store(cache) + return svc +} + +func TestShouldEmulateWebSearch_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody)) +} + +func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + // Global config disabled + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: false, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(false) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDisabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + // nil groupID + default mode → falls through to channel check → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc, channelService: nil} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + // nil channelService + default mode → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} diff --git a/backend/internal/service/notify_email_entry_test.go b/backend/internal/service/notify_email_entry_test.go new file mode 100644 index 00000000..0f4bb12e --- /dev/null +++ b/backend/internal/service/notify_email_entry_test.go @@ -0,0 +1,156 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ---------- ParseNotifyEmails ---------- + +func TestParseNotifyEmails_EmptyString(t *testing.T) { + result := ParseNotifyEmails("") + require.Nil(t, result) +} + +func TestParseNotifyEmails_EmptyArray(t *testing.T) { + result := ParseNotifyEmails("[]") + require.Nil(t, result) +} + +func TestParseNotifyEmails_Null(t *testing.T) { + // "null" is valid JSON that unmarshals into a nil string slice. + // The old-format branch then returns an empty (non-nil) slice. + result := ParseNotifyEmails("null") + require.Empty(t, result) +} + +func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) { + result := ParseNotifyEmails(" ") + require.Nil(t, result) +} + +func TestParseNotifyEmails_OldFormat(t *testing.T) { + raw := `["alice@example.com", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.False(t, result[0].Verified, "old format emails should default to unverified") + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.False(t, result[1].Disabled) +} + +func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) { + raw := `["alice@example.com", "", " ", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + require.Equal(t, "alice@example.com", result[0].Email) + require.Equal(t, "bob@example.com", result[1].Email) +} + +func TestParseNotifyEmails_NewFormat(t *testing.T) { + raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.True(t, result[0].Verified) + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.True(t, result[1].Disabled) +} + +func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) { + raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "solo@example.com", result[0].Email) + require.True(t, result[0].Verified) +} + +func TestParseNotifyEmails_InvalidJSON(t *testing.T) { + result := ParseNotifyEmails(`{not valid json`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) { + // A plain JSON object (not array) should return nil. + result := ParseNotifyEmails(`{"email":"a@b.com"}`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_WhitespacePadding(t *testing.T) { + raw := ` ["padded@example.com"] ` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "padded@example.com", result[0].Email) +} + +// ---------- MarshalNotifyEmails ---------- + +func TestMarshalNotifyEmails_EmptySlice(t *testing.T) { + result := MarshalNotifyEmails([]NotifyEmailEntry{}) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_NilSlice(t *testing.T) { + result := MarshalNotifyEmails(nil) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_SingleEntry(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "test@example.com", Verified: true, Disabled: false}, + } + result := MarshalNotifyEmails(entries) + require.Contains(t, result, `"email":"test@example.com"`) + require.Contains(t, result, `"verified":true`) + require.Contains(t, result, `"disabled":false`) + + // Round-trip: parsing the marshalled result should produce the original entries. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 1) + require.Equal(t, entries[0], parsed[0]) +} + +func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "a@example.com", Verified: true, Disabled: false}, + {Email: "b@example.com", Verified: false, Disabled: true}, + } + result := MarshalNotifyEmails(entries) + + // Round-trip verification. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 2) + require.Equal(t, entries[0], parsed[0]) + require.Equal(t, entries[1], parsed[1]) +} + +func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) { + original := []NotifyEmailEntry{ + {Email: "x@example.com", Verified: true, Disabled: true}, + {Email: "y@example.com", Verified: false, Disabled: false}, + } + marshalled := MarshalNotifyEmails(original) + parsed := ParseNotifyEmails(marshalled) + require.Equal(t, original, parsed) +} + +// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ---------- + +func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) { + // Emails with leading/trailing whitespace in old format should be trimmed. + raw := `[" alice@example.com "]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "alice@example.com", result[0].Email) +} diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 8cd50d0d..c5b96e01 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -1,9 +1,12 @@ +//go:build unit + package service import ( "context" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -141,3 +144,123 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { _ = SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "secret", cfg.Providers[0].APIKey) } + +// --- PopulateWebSearchUsage --- + +func TestPopulateWebSearchUsage_NilInput(t *testing.T) { + require.Nil(t, PopulateWebSearchUsage(context.Background(), nil)) +} + +func TestPopulateWebSearchUsage_NoManager_QuotaUsedZero(t *testing.T) { + // Ensure no global manager is set + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out) + require.Len(t, out.Providers, 1) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_True(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key"}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_False(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: ""}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.False(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_NilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: nil}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Nil(t, out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_NonNilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(500)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out.Providers[0].QuotaLimit) + require.Equal(t, int64(500), *out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_WithManager_NilRedis(t *testing.T) { + // Manager with nil Redis returns 0 usage without error + mgr := websearch.NewManager([]websearch.ProviderConfig{ + {Type: "brave", APIKey: "k"}, + }, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_DoesNotMutateOriginal(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(100)}, + }, + } + _ = PopulateWebSearchUsage(context.Background(), cfg) + // Original should be unchanged + require.Equal(t, "secret", cfg.Providers[0].APIKey) + require.Equal(t, int64(0), cfg.Providers[0].QuotaUsed) +} + +// --- ResetWebSearchUsage --- + +func TestResetWebSearchUsage_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + err := ResetWebSearchUsage(context.Background(), "brave") + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +} From d6965b0676eadba10ba499436302ccb8610af420 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 10:18:39 +0800 Subject: [PATCH 085/122] fix: resolve cherry-pick conflicts and restore compilation - Restore gateway_cache.go to upstream (no lua embeds) - Restore payment_order.go to upstream (use out_trade_no lookup) - Restore payment_fulfillment.go to upstream (same reason) - Add FeaturesConfig field and IsWebSearchEmulationEnabled to Channel - Add applyAccountStatsCost wrapper function - Add SettingKeyWebSearchEmulationConfig constant - Add WebSearchEmulationEnabled to SystemSettings - Add notify code rate limiting methods to EmailCache interface - Remove AllowUserRefund references (ent schema not present) - Fix duplicate import in payment_handler.go - Fix wire_gen.go argument mismatches --- backend/cmd/server/wire_gen.go | 4 +- backend/internal/handler/payment_handler.go | 1 - backend/internal/repository/gateway_cache.go | 257 +----------------- .../internal/service/account_stats_pricing.go | 21 ++ backend/internal/service/channel.go | 16 +- backend/internal/service/domain_constants.go | 3 + backend/internal/service/email_service.go | 4 + .../service/payment_config_providers.go | 21 +- .../service/payment_config_service.go | 2 - .../internal/service/payment_fulfillment.go | 14 +- backend/internal/service/payment_order.go | 254 +---------------- backend/internal/service/settings_view.go | 3 + 12 files changed, 80 insertions(+), 520 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 69daeecf..a0e84f4c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -143,7 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) @@ -217,8 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index e01a2af1..0425fc49 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -7,7 +7,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index ec4bf40e..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,42 +2,14 @@ package repository import ( "context" - _ "embed" "fmt" - "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) -const ( - stickySessionPrefix = "sticky_session:" - clientAffinityPrefix = "client_affinity:" - clientAffinityReversePrefix = "client_affinity_rev:" -) - -var ( - //go:embed lua/get_affinity.lua - getAffinityLua string - //go:embed lua/update_affinity.lua - updateAffinityLua string - //go:embed lua/get_affinity_count.lua - getAffinityCountLua string - //go:embed lua/get_affinity_clients.lua - getAffinityClientsLua string - //go:embed lua/get_affinity_clients_with_scores.lua - getAffinityClientsWithScoresLua string - //go:embed lua/clear_account_affinity.lua - clearAccountAffinityLua string - - getAffinityScript = redis.NewScript(getAffinityLua) - updateAffinityScript = redis.NewScript(updateAffinityLua) - getAffinityCountScript = redis.NewScript(getAffinityCountLua) - getAffinityClientsScript = redis.NewScript(getAffinityClientsLua) - getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua) - clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua) -) +const stickySessionPrefix = "sticky_session:" type gatewayCache struct { rdb *redis.Client @@ -47,16 +19,6 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } -// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。 -// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失, -// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。 -func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) { - exists, err := script.Exists(ctx, rdb).Result() - if err != nil || len(exists) == 0 || !exists[0] { - _ = script.Load(ctx, rdb).Err() - } -} - // buildSessionKey 构建 session key,包含 groupID 实现分组隔离 // 格式: sticky_session:{groupID}:{sessionHash} func buildSessionKey(groupID int64, sessionHash string) string { @@ -79,218 +41,13 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 +// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, +// 以便下次请求能够重新选择可用账号。 +// +// DeleteSessionAccountID removes the sticky session binding for the given session. +// Called when the bound account becomes unavailable (e.g., error status, disabled, +// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// buildAffinityKey 构建正向亲和 key(client → accounts) -// 格式: client_affinity:{groupID}:{clientID} -func buildAffinityKey(groupID int64, clientID string) string { - return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID) -} - -// buildAffinityReverseKey 构建反向亲和 key(account → clients) -// 格式: client_affinity_rev:{groupID}:{accountID} -func buildAffinityReverseKey(groupID int64, accountID int64) string { - return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID) -} - -func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) { - key := buildAffinityKey(groupID, clientID) - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice() - if err != nil { - if err == redis.Nil { - return nil, nil - } - return nil, err - } - - accountIDs := make([]int64, 0, len(result)) - for _, s := range result { - id, err := strconv.ParseInt(s, 10, 64) - if err != nil { - continue - } - accountIDs = append(accountIDs, id) - } - return accountIDs, nil -} - -func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error { - fwdKey := buildAffinityKey(groupID, clientID) - revKey := buildAffinityReverseKey(groupID, accountID) - now := time.Now().Unix() - ttlSeconds := int64(ttl.Seconds()) - expireThreshold := now - ttlSeconds - - return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey}, - now, ttlSeconds, accountID, expireThreshold, clientID, - ).Err() -} - -// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员) -func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) { - if len(accountIDs) == 0 { - return map[int64]int64{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(accountIDs)) - for i, accID := range accountIDs { - key := buildAffinityReverseKey(groupID, accID) - cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - result := make(map[int64]int64, len(accountIDs)) - for i, accID := range accountIDs { - count, _ := cmds[i].Int64() - result[accID] = count - } - return result, nil -} - -// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。 -// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。 -func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) { - if len(accountGroups) == 0 { - return map[int64][]string{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - // 构建所有 (accountID, groupID) 组合的查询 - type queryItem struct { - accountID int64 - groupID int64 - } - var queries []queryItem - for accID, groupIDs := range accountGroups { - for _, gID := range groupIDs { - queries = append(queries, queryItem{accountID: accID, groupID: gID}) - } - } - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(queries)) - for i, q := range queries { - key := buildAffinityReverseKey(q.groupID, q.accountID) - cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重 - result := make(map[int64][]string, len(accountGroups)) - seen := make(map[int64]map[string]struct{}, len(accountGroups)) - for i, q := range queries { - clients, _ := cmds[i].StringSlice() - if len(clients) == 0 { - continue - } - if seen[q.accountID] == nil { - seen[q.accountID] = make(map[string]struct{}) - } - for _, clientID := range clients { - if _, exists := seen[q.accountID][clientID]; !exists { - seen[q.accountID][clientID] = struct{}{} - result[q.accountID] = append(result[q.accountID], clientID) - } - } - } - return result, nil -} - -// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。 -func (c *gatewayCache) GetAccountAffinityClientsWithScores( - ctx context.Context, - accountID int64, - groupIDs []int64, - ttl time.Duration, -) ([]service.AffinityClient, error) { - if len(groupIDs) == 0 { - return nil, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(groupIDs)) - for i, gID := range groupIDs { - key := buildAffinityReverseKey(gID, accountID) - cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并跨组结果,同一 clientID 取最近的 lastActive - seen := make(map[string]int64) // clientID → max timestamp - for _, cmd := range cmds { - vals, _ := cmd.StringSlice() - // vals 格式: [clientID1, score1, clientID2, score2, ...] - for j := 0; j+1 < len(vals); j += 2 { - clientID := vals[j] - ts, _ := strconv.ParseInt(vals[j+1], 10, 64) - if existing, ok := seen[clientID]; !ok || ts > existing { - seen[clientID] = ts - } - } - } - - result := make([]service.AffinityClient, 0, len(seen)) - for clientID, ts := range seen { - result = append(result, service.AffinityClient{ - ClientID: clientID, - LastActive: time.Unix(ts, 0), - }) - } - - // 按最后活跃时间降序排序 - service.SortAffinityClients(result) - - return result, nil -} - -// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。 -// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端, -// 从每个客户端的正向索引中移除该账号,然后删除反向索引。 -func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error { - if len(groupIDs) == 0 { - return nil - } - - ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript) - - pipe := c.rdb.Pipeline() - for _, gID := range groupIDs { - revKey := buildAffinityReverseKey(gID, accountID) - clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return err - } - return nil -} diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 61c318d9..47b7496f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -227,3 +227,24 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) * } return &cost } + +// applyAccountStatsCost resolves the account stats cost for a usage log entry. +// It resolves the upstream model (falling back to the requested model) and calls +// the 4-level priority chain via resolveAccountStatsCost. +func applyAccountStatsCost( + ctx context.Context, + usageLog *UsageLog, + cs *ChannelService, bs *BillingService, + accountID int64, groupID int64, + upstreamModel, requestedModel string, + tokens UsageTokens, + totalCost float64, +) { + model := upstreamModel + if model == "" { + model = requestedModel + } + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ) +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index b034fda0..b3fb2eac 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -39,7 +39,8 @@ type Channel struct { Status string BillingModelSource string // "requested", "upstream", or "channel_mapped" RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) - Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation) CreatedAt time.Time UpdatedAt time.Time @@ -222,6 +223,19 @@ func (c *Channel) Clone() *Channel { return &cp } +// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。 +func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool { + if c == nil || c.FeaturesConfig == nil { + return false + } + wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any) + if !ok { + return false + } + enabled, ok := wse[platform].(bool) + return ok && enabled +} + // deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. func deepCopyFeaturesConfig(src map[string]any) map[string]any { dst := make(map[string]any, len(src)) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index bdced29a..cb452efb 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -258,6 +258,9 @@ const ( // Account Quota Notification SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) + + // Web Search Emulation + SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置 ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 9cfd3bbd..9a03ea30 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -49,6 +49,10 @@ type EmailCache interface { // Returns true if in cooldown period (email was sent recently) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error + + // Notify code rate limiting per user + IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) + GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) } // VerificationCodeData represents verification code data diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 0c71ab29..072ed002 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -30,7 +30,6 @@ type ProviderInstanceResponse struct { Limits string `json:"limits"` Enabled bool `json:"enabled"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` SortOrder int `json:"sort_order"` PaymentMode string `json:"payment_mode"` } @@ -47,7 +46,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) @@ -111,12 +110,10 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } - allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). - SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -224,21 +221,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) - // Cascade: turning off refund_enabled also disables allow_user_refund - if !*req.RefundEnabled { - u.SetAllowUserRefund(false) - } - } - if req.AllowUserRefund != nil { - // Only allow enabling when refund_enabled is true - if *req.AllowUserRefund { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil && inst.RefundEnabled { - u.SetAllowUserRefund(true) - } - } else { - u.SetAllowUserRefund(false) - } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -250,7 +232,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). Where( - paymentproviderinstance.AllowUserRefundEQ(true), paymentproviderinstance.RefundEnabledEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index cce31f4d..6d470342 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -114,7 +114,6 @@ type CreateProviderInstanceRequest struct { SortOrder int `json:"sort_order"` Limits string `json:"limits"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { @@ -126,7 +125,6 @@ type UpdateProviderInstanceRequest struct { SortOrder *int `json:"sort_order"` Limits *string `json:"limits"` RefundEnabled *bool `json:"refund_enabled"` - AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 51307849..de41d742 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "math" + "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -20,11 +22,17 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme if n.Status != payment.NotificationStatusSuccess { return nil } - oid, err := parseOrderID(n.OrderID) + // Look up order by out_trade_no (the external order ID we sent to the provider) + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) if err != nil { - return fmt.Errorf("invalid order ID: %s", n.OrderID) + // Fallback: try legacy format (sub2_N where N is DB ID) + trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) + if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + } + return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) } - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) } func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index e81af3f5..ff4dfaa8 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -10,7 +10,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -170,68 +169,6 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return nil } -func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error { - if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 { - return nil - } - windowStart := cancelRateLimitWindowStart(cfg) - operator := fmt.Sprintf("user:%d", userID) - count, err := s.entClient.PaymentAuditLog.Query(). - Where( - paymentauditlog.ActionEQ("ORDER_CANCELLED"), - paymentauditlog.OperatorEQ(operator), - paymentauditlog.CreatedAtGTE(windowStart), - ).Count(ctx) - if err != nil { - slog.Error("check cancel rate limit failed", "userID", userID, "error", err) - return nil // fail open - } - if count >= cfg.CancelRateLimitMax { - return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited"). - WithMetadata(map[string]string{ - "max": strconv.Itoa(cfg.CancelRateLimitMax), - "window": strconv.Itoa(cfg.CancelRateLimitWindow), - "unit": cfg.CancelRateLimitUnit, - }) - } - return nil -} - -func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time { - now := time.Now() - w := cfg.CancelRateLimitWindow - if w <= 0 { - w = 1 - } - unit := cfg.CancelRateLimitUnit - if unit == "" { - unit = "day" - } - if cfg.CancelRateLimitMode == "fixed" { - switch unit { - case "minute": - t := now.Truncate(time.Minute) - return t.Add(-time.Duration(w-1) * time.Minute) - case "day": - y, m, d := now.Date() - t := time.Date(y, m, d, 0, 0, 0, 0, now.Location()) - return t.AddDate(0, 0, -(w - 1)) - default: // hour - t := now.Truncate(time.Hour) - return t.Add(-time.Duration(w-1) * time.Hour) - } - } - // rolling window - switch unit { - case "minute": - return now.Add(-time.Duration(w) * time.Minute) - case "day": - return now.AddDate(0, 0, -w) - default: // hour - return now.Add(-time.Duration(w) * time.Hour) - } -} - func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil @@ -252,19 +189,16 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user } func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { - s.EnsureProviders(ctx) - providerKey := s.registry.GetProviderKey(req.PaymentType) - if providerKey == "" { - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) - } - sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) + // Select an instance across all providers that support the requested payment type. + // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). + sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) if err != nil { - return nil, fmt.Errorf("select provider instance: %w", err) + return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) } if sel == nil { return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") } - prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config) + prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") } @@ -272,7 +206,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen outTradeNo := order.OutTradeNo pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) if err != nil { - slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err) + slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) } _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) @@ -357,6 +291,13 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or if p.PaymentType != "" { q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType)) } + if p.Keyword != "" { + q = q.Where(paymentorder.Or( + paymentorder.OutTradeNoContainsFold(p.Keyword), + paymentorder.UserEmailContainsFold(p.Keyword), + paymentorder.UserNameContainsFold(p.Keyword), + )) + } total, err := q.Clone().Count(ctx) if err != nil { return nil, 0, fmt.Errorf("count admin orders: %w", err) @@ -368,172 +309,3 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or } return orders, total, nil } - -// --- Cancel & Expire --- - -func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order") -} - -func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order") -} - -func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) { - if o.PaymentTradeNo != "" || o.PaymentType != "" { - if s.checkPaid(ctx, o) == "already_paid" { - return "already_paid", nil - } - } - c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx) - if err != nil { - return "", fmt.Errorf("update order status: %w", err) - } - if c > 0 { - auditAction := "ORDER_CANCELLED" - if fs == OrderStatusExpired { - auditAction = "ORDER_EXPIRED" - } - s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad}) - } - return "cancelled", nil -} - -func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { - prov, err := s.getOrderProvider(ctx, o) - if err != nil { - return "" - } - // Use OutTradeNo as fallback when PaymentTradeNo is empty - // (e.g. EasyPay popup mode where trade_no arrives only via notify callback) - tradeNo := o.PaymentTradeNo - if tradeNo == "" { - tradeNo = o.OutTradeNo - } - resp, err := prov.QueryOrder(ctx, tradeNo) - if err != nil { - slog.Warn("query upstream failed", "orderID", o.ID, "error", err) - return "" - } - if resp.Status == payment.ProviderStatusPaid { - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { - slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) - // Still return already_paid — order was paid, fulfillment can be retried - } - return "already_paid" - } - if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, tradeNo) - } - return "" -} - -// VerifyOrderByOutTradeNo actively queries the upstream provider to check -// if a payment was made, and processes it if so. This handles the case where -// the provider's notify callback was missed (e.g. EasyPay popup mode). -func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - // Only verify orders that are still pending or recently expired - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - // Reload order to get updated status - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -// VerifyOrderPublic verifies payment status without user authentication. -// Used by the payment result page when the user's session has expired. -func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { - now := time.Now() - orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) - if err != nil { - return 0, fmt.Errorf("query expired: %w", err) - } - n := 0 - for _, o := range orders { - // Check upstream payment status before expiring — the user may have - // paid just before timeout and the webhook hasn't arrived yet. - outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired") - if outcome == "already_paid" { - slog.Info("order was paid during expiry", "orderID", o.ID) - continue - } - if outcome != "" { - n++ - } - } - return n, nil -} - -// getOrderProvider creates a provider using the order's original instance config. -// Falls back to registry lookup if instance ID is missing (legacy orders). -func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } - } - s.EnsureProviders(ctx) - return s.registry.GetProvider(o.PaymentType) -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ec20fe0a..ab2eb274 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -107,6 +107,9 @@ type SystemSettings struct { EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + // Web Search Emulation + WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 From 24e16b7f599eba35e180ae2e53b5e042775a6421 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 10:58:51 +0800 Subject: [PATCH 086/122] fix: restore resolveOpenAIMessagesDispatchMappedModel and reset VERSION - Restore function deleted during cherry-pick conflict resolution - Reset VERSION to upstream 0.1.112 --- backend/cmd/server/VERSION | 2 +- backend/internal/handler/openai_gateway_handler.go | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 630554d9..4b9b35d8 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.112.4 +0.1.112 diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index dda6d2e3..6c5a6779 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode return strings.TrimSpace(apiKey.Group.DefaultMappedModel) } +func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, From b42f34c359251bd374d957092915a7e79616e0db Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 11:27:32 +0800 Subject: [PATCH 087/122] fix: resolve test compilation errors and restore upstream VERSION - Add missing interface methods to test stubs (RemoveGroupFromUserAllowedGroups, GetNotifyCodeUserRate, IncrNotifyCodeUserRate, UpdateGroupIDByUserAndGroup) - Fix NewUserService call signatures (add 4th param) - Fix GetAccountCount return signature (3 values) - Update api_contract_test.go snapshots for balance_notify fields - Restore resolveOpenAIMessagesDispatchMappedModel function - Reset VERSION to upstream 0.1.112 --- backend/internal/server/api_contract_test.go | 14 ++++++++++++-- .../internal/server/middleware/admin_auth_test.go | 2 +- .../internal/server/middleware/jwt_auth_test.go | 2 +- .../internal/service/admin_service_apikey_test.go | 8 +++++++- .../internal/service/auth_service_register_test.go | 8 ++++++++ backend/internal/service/user_service_test.go | 3 +++ 6 files changed, 32 insertions(+), 5 deletions(-) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 08291faa..44c3f0e4 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) { "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z", + "balance_notify_enabled": false, + "balance_notify_threshold_type": "", + "balance_notify_threshold": null, + "balance_notify_extra_emails": null, + "total_recharged": 0, "run_mode": "standard" } }`, @@ -606,7 +611,12 @@ func TestAPIContracts(t *testing.T) { "payment_cancel_rate_limit_max": 0, "payment_cancel_rate_limit_window": 0, "payment_cancel_rate_limit_unit": "", - "payment_cancel_rate_limit_window_mode": "" + "payment_cancel_rate_limit_window_mode": "", + "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, + "balance_low_notify_threshold": 0, + "balance_low_notify_recharge_url": "", + "account_quota_notify_emails": [] } }`, }, @@ -699,7 +709,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index aafe4a58..ed2578c8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { return &clone, nil }, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) router := gin.New() router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index ad9c1b5b..c483a51e 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer userRepo := &stubJWTUserRepo{users: users} authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) - userSvc := service.NewUserService(userRepo, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) r := gin.New() diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 5c18a438..7f0a24da 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -70,6 +70,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected") +} // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. type apiKeyRepoStubForGroupUpdate struct { @@ -152,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { + panic("unexpected") +} // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. type groupRepoStubForGroupUpdate struct { @@ -194,7 +200,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 0999b4f0..103bafe7 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -119,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai return nil } +func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + return 0, nil +} + +func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + return 0, nil +} + func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 29267c19..a998d5f4 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -49,6 +49,9 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} // --- mock: APIKeyAuthCacheInvalidator --- From 4aa0070e3d0ac672e514713eec5b8194ce4160b2 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 11:31:44 +0800 Subject: [PATCH 088/122] fix: Stripe payment type matching in load balancer Checkout page aggregates Stripe sub-types (card,link,alipay,wxpay) under "stripe", but SelectInstance matched against supported_types literally, which doesn't contain "stripe". Now matches by provider_key for Stripe. --- backend/internal/payment/load_balancer.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index 55cb2043..f0353173 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( var matched []*dbent.PaymentProviderInstance for _, inst := range instances { - if InstanceSupportsType(inst.SupportedTypes, paymentType) { + // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), + // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". + if paymentType == TypeStripe { + if inst.ProviderKey == TypeStripe { + matched = append(matched, inst) + } + } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { matched = append(matched, inst) } } From 6a08efeef9d013d4e5b3551b8c800adca267db63 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 12:11:08 +0800 Subject: [PATCH 089/122] fix: resolve upstream CI failures (lint, test, gofmt) - Fix errcheck: handle Write/Encode return values in brave_test.go - Fix errcheck: defer resp.Body.Close() with _ assignment in tavily.go - Fix gofmt: payment.go, channel.go, payment_config_providers.go - Fix unused: remove dead decodeURLValue in easypay.go - Restore shouldFallbackGeminiModel function (deleted during cherry-pick) - Add missing balanceNotifyService param to NewGatewayService in test - Fix platform default test expectation (empty stays empty) - Fix wildcard pricing test (longest prefix wins, not config order) - Fix subscription group test (SUBSCRIPTION_REPOSITORY_UNAVAILABLE) --- .../handler/admin/channel_handler_test.go | 4 ++-- ...eway_handler_warmup_intercept_unit_test.go | 1 + .../internal/handler/gemini_v1beta_handler.go | 10 ++++++++++ backend/internal/payment/provider/easypay.go | 9 --------- backend/internal/pkg/websearch/brave_test.go | 10 +++++----- backend/internal/pkg/websearch/tavily.go | 2 +- backend/internal/server/routes/payment.go | 1 - .../service/account_stats_pricing_test.go | 4 ++-- .../service/admin_service_apikey_test.go | 4 ++-- backend/internal/service/channel.go | 4 ++-- .../service/payment_config_providers.go | 20 +++++++++---------- 11 files changed, 35 insertions(+), 34 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 2f4b4440..f218cce4 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) { wantValue: string(service.BillingModeToken), }, { - name: "empty platform defaults to anthropic", + name: "empty platform stays empty", req: channelModelPricingRequest{ Models: []string{"m1"}, Platform: "", }, wantField: "Platform", - wantValue: "anthropic", + wantValue: "", }, } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index acea3780..1fdc46ba 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // tlsFPProfileService nil, // channelService nil, // resolver + nil, // balanceNotifyService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 45b5842f..6b8cc482 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -682,6 +682,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index c54aba6a..b48a38fe 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -276,12 +276,3 @@ func easyPaySign(params map[string]string, pkey string) string { func easyPayVerifySign(params map[string]string, pkey string, sign string) bool { return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign)) } - -// decodeURLValue URL-decodes a string once. -func decodeURLValue(s string) string { - decoded, err := url.QueryUnescape(s) - if err != nil { - return s - } - return decoded -} diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go index 3fe35020..4dc5b219 100644 --- a/backend/internal/pkg/websearch/brave_test.go +++ b/backend/internal/pkg/websearch/brave_test.go @@ -29,7 +29,7 @@ func TestBraveProvider_Search_Success(t *testing.T) { {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"}, } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -53,7 +53,7 @@ func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { receivedCount = r.URL.Query().Get("count") resp := braveResponse{} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -70,7 +70,7 @@ func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) { func TestBraveProvider_Search_HTTPError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(429) - w.Write([]byte("rate limited")) + _, _ = w.Write([]byte("rate limited")) })) defer srv.Close() @@ -86,7 +86,7 @@ func TestBraveProvider_Search_HTTPError(t *testing.T) { func TestBraveProvider_Search_InvalidJSON(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("not json")) + _, _ = w.Write([]byte("not json")) })) defer srv.Close() @@ -103,7 +103,7 @@ func TestBraveProvider_Search_InvalidJSON(t *testing.T) { func TestBraveProvider_Search_EmptyResults(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go index 6ac09edf..ac4928a6 100644 --- a/backend/internal/pkg/websearch/tavily.go +++ b/backend/internal/pkg/websearch/tavily.go @@ -60,7 +60,7 @@ func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*Search if err != nil { return nil, fmt.Errorf("tavily: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) if err != nil { diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 641c6cd5..72012a4e 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -78,7 +78,6 @@ func RegisterPaymentRoutes( adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund) } - // Subscription Plans plans := adminGroup.Group("/plans") { diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 36e5eb74..2f625393 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "wildcard matches by config order (first match wins)", + name: "wildcard matches by longest prefix (most specific wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 10, // config order: "claude-*" is first and matches, so it wins + wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 7f0a24da..1e235278 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -412,10 +412,10 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *test userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} - // 订阅类型分组应被阻止绑定 + // userSubRepo is nil → SUBSCRIPTION_REPOSITORY_UNAVAILABLE _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index b3fb2eac..93beb972 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -37,8 +37,8 @@ type Channel struct { Name string Description string Status string - BillingModelSource string // "requested", "upstream", or "channel_mapped" - RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + BillingModelSource string // "requested", "upstream", or "channel_mapped" + RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) Features string // 渠道特性描述(JSON 数组),用于支付页面展示 FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation) CreatedAt time.Time diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 072ed002..47008df0 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,16 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. From e8ee400a3f8720345f025f8ba8e6164d0950592c Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 12:19:44 +0800 Subject: [PATCH 090/122] fix: resolve remaining lint errors for upstream CI - Fix errcheck: brave.go resp.Body.Close, manager_test.go Encode - Fix gofmt: payment_config_service.go - Fix unused: use shouldFallbackGeminiModel (with modelName param) in handler --- .../internal/handler/gemini_v1beta_handler.go | 2 +- backend/internal/pkg/websearch/brave.go | 2 +- .../internal/pkg/websearch/manager_test.go | 4 +-- .../service/payment_config_service.go | 34 +++++++++---------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 6b8cc482..d200c17c 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go index 5620ca8d..707e7029 100644 --- a/backend/internal/pkg/websearch/brave.go +++ b/backend/internal/pkg/websearch/brave.go @@ -62,7 +62,7 @@ func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchR if err != nil { return nil, fmt.Errorf("brave: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) if err != nil { diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index cbcf1b76..a4413417 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -50,7 +50,7 @@ func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) { srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srvBrave.Close() @@ -77,7 +77,7 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 6d470342..9042c3ab 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,26 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` From f1297a3694973d7ba9274d7b5e5874aafa1dfa56 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 16:26:46 +0800 Subject: [PATCH 091/122] feat: add per-provider allow_user_refund control and align wildcard matching allow_user_refund: - Add allow_user_refund field to PaymentProviderInstance ent schema - Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN - Cascade logic: disabling refund_enabled auto-disables allow_user_refund - User refund validation: check provider instance allows user refund - Admin refund validation: check provider instance allows admin refund - Subscription refund: deduct days on refund, rollback on failure - New endpoint: GET /payment/orders/refund-eligible-providers - Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView Wildcard matching: - Change findPricingForModel from "longest prefix wins" to "config order priority (first match wins)", aligning with channel service behavior --- backend/ent/client.go | 36 +++---- backend/ent/intercept/intercept.go | 1 - backend/ent/migrate/schema.go | 1 + backend/ent/mutation.go | 94 +++++++++++++++---- backend/ent/paymentproviderinstance.go | 13 ++- .../paymentproviderinstance.go | 10 ++ backend/ent/paymentproviderinstance/where.go | 15 +++ backend/ent/paymentproviderinstance_create.go | 65 +++++++++++++ backend/ent/paymentproviderinstance_update.go | 34 +++++++ backend/ent/predicate/predicate.go | 1 - backend/ent/runtime/runtime.go | 8 +- .../ent/schema/payment_provider_instance.go | 2 + backend/internal/handler/payment_handler.go | 10 ++ backend/internal/server/routes/payment.go | 1 + .../internal/service/account_stats_pricing.go | 22 +---- .../service/account_stats_pricing_test.go | 4 +- .../service/payment_config_providers.go | 42 ++++++--- .../service/payment_config_service.go | 36 +++---- backend/internal/service/payment_refund.go | 64 +++++++++++++ .../migrations/103_add_allow_user_refund.sql | 1 + frontend/src/api/payment.ts | 5 + .../payment/PaymentProviderDialog.vue | 8 +- .../payment/PaymentProviderList.vue | 2 +- .../src/components/payment/ProviderCard.vue | 3 +- frontend/src/i18n/locales/en.ts | 1 + frontend/src/i18n/locales/zh.ts | 1 + frontend/src/types/payment.ts | 2 + frontend/src/views/admin/SettingsView.vue | 21 ++++- 28 files changed, 405 insertions(+), 98 deletions(-) create mode 100644 backend/migrations/103_add_allow_user_refund.sql diff --git a/backend/ent/client.go b/backend/ent/client.go index 3da7acf8..e52e015a 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 77d3e16e..8d8320bb 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -336,7 +336,6 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q) } - // The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 1fff61ba..68bdbf55 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -616,6 +616,7 @@ var ( {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "refund_enabled", Type: field.TypeBool, Default: false}, + {Name: "allow_user_refund", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 3bca248d..524ccb92 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error { // PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. type PaymentProviderInstanceMutation struct { config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance } var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) @@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { m.refund_enabled = nil } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return + } + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + } + return oldValue.AllowUserRefund, nil +} + +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil +} + // SetCreatedAt sets the "created_at" field. func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.provider_key != nil { fields = append(fields, paymentproviderinstance.FieldProviderKey) } @@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { if m.refund_enabled != nil { fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + } if m.created_at != nil { fields = append(fields, paymentproviderinstance.FieldCreatedAt) } @@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { return m.Limits() case paymentproviderinstance.FieldRefundEnabled: return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() case paymentproviderinstance.FieldCreatedAt: return m.CreatedAt() case paymentproviderinstance.FieldUpdatedAt: @@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str return m.OldLimits(ctx) case paymentproviderinstance.FieldRefundEnabled: return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) case paymentproviderinstance.FieldCreatedAt: return m.OldCreatedAt(ctx) case paymentproviderinstance.FieldUpdatedAt: @@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) } m.SetRefundEnabled(v) return nil + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) + return nil case paymentproviderinstance.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error { case paymentproviderinstance.FieldRefundEnabled: m.ResetRefundEnabled() return nil + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() + return nil case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go index 087cb13a..4279b86e 100644 --- a/backend/ent/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance.go @@ -35,6 +35,8 @@ type PaymentProviderInstance struct { Limits string `json:"limits,omitempty"` // RefundEnabled holds the value of the "refund_enabled" field. RefundEnabled bool `json:"refund_enabled,omitempty"` + // AllowUserRefund holds the value of the "allow_user_refund" field. + AllowUserRefund bool `json:"allow_user_refund,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: + case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund: values[i] = new(sql.NullBool) case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: values[i] = new(sql.NullInt64) @@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) } else if value.Valid { _m.RefundEnabled = value.Bool } + case paymentproviderinstance.FieldAllowUserRefund: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i]) + } else if value.Valid { + _m.AllowUserRefund = value.Bool + } case paymentproviderinstance.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string { builder.WriteString("refund_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(", ") + builder.WriteString("allow_user_refund=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go index c430fef6..eb1b0c52 100644 --- a/backend/ent/paymentproviderinstance/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go @@ -31,6 +31,8 @@ const ( FieldLimits = "limits" // FieldRefundEnabled holds the string denoting the refund_enabled field in the database. FieldRefundEnabled = "refund_enabled" + // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database. + FieldAllowUserRefund = "allow_user_refund" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -51,6 +53,7 @@ var Columns = []string{ FieldSortOrder, FieldLimits, FieldRefundEnabled, + FieldAllowUserRefund, FieldCreatedAt, FieldUpdatedAt, } @@ -88,6 +91,8 @@ var ( DefaultLimits string // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. DefaultRefundEnabled bool + // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field. + DefaultAllowUserRefund bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() } +// ByAllowUserRefund orders the results by the allow_user_refund field. +func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go index 7b99517f..40e5a1f6 100644 --- a/backend/ent/paymentproviderinstance/where.go +++ b/backend/ent/paymentproviderinstance/where.go @@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) } +// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ. +func AllowUserRefund(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) @@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) } +// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field. +func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + +// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field. +func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go index 20b16ddd..d1b14617 100644 --- a/backend/ent/paymentproviderinstance_create.go +++ b/backend/ent/paymentproviderinstance_create.go @@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym return _c } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate { + _c.mutation.SetAllowUserRefund(v) + return _c +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate { + if v != nil { + _c.SetAllowUserRefund(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { _c.mutation.SetCreatedAt(v) @@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() { v := paymentproviderinstance.DefaultRefundEnabled _c.mutation.SetRefundEnabled(v) } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + v := paymentproviderinstance.DefaultAllowUserRefund + _c.mutation.SetAllowUserRefund(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := paymentproviderinstance.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error { if _, ok := _c.mutation.RefundEnabled(); !ok { return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} } @@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _node.RefundEnabled = value } + if value, ok := _c.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + _node.AllowUserRefund = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn return u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert { + u.Set(paymentproviderinstance.FieldAllowUserRefund, v) + return u +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert { + u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund) + return u +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { u.Set(paymentproviderinstance.FieldUpdatedAt, v) @@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { return u.Update(func(s *PaymentProviderInstanceUpsert) { @@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { return u.Update(func(s *PaymentProviderInstanceUpsert) { diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go index 06dba527..6bb3a82d 100644 --- a/backend/ent/paymentproviderinstance_update.go +++ b/backend/ent/paymentproviderinstance_update.go @@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { _u.mutation.SetUpdatedAt(v) @@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } @@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { _u.mutation.SetUpdatedAt(v) @@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 67f37c75..ef551940 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,7 +33,6 @@ type IdempotencyRecord func(*sql.Selector) // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) - // PaymentOrder is the predicate function for paymentorder builders. type PaymentOrder func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 951b5f99..fbdd08c7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -668,12 +668,16 @@ func init() { paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) + // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field. + paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor() + // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field. + paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool) // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. - paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() + paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor() // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. - paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() + paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor() // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go index 08ab7d31..e4c0b72c 100644 --- a/backend/ent/schema/payment_provider_instance.go +++ b/backend/ent/schema/payment_provider_instance.go @@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field { Default(""), field.Bool("refund_enabled"). Default(false), + field.Bool("allow_user_refund"). + Default(false), field.Time("created_at"). Immutable(). Default(time.Now). diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 0425fc49..5fde86fa 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { response.Success(c, gin.H{"message": "refund requested"}) } +// GetRefundEligibleProviders returns provider instance IDs that allow user refund. +func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) { + ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"provider_instance_ids": ids}) +} + // VerifyOrderRequest is the request body for verifying a payment order. type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 72012a4e..8def7559 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -37,6 +37,7 @@ func RegisterPaymentRoutes( orders.GET("/:id", paymentHandler.GetOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/refund-request", paymentHandler.RequestRefund) + orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders) } } diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 47b7496f..90ff450f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -2,7 +2,6 @@ package service import ( "context" - "sort" "strings" ) @@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int return false } -// wildcardMatch 通配符匹配候选项(用于排序) -type wildcardMatch struct { - prefixLen int - pricing *ChannelModelPricing -} - // findPricingForModel 在定价列表中查找匹配的模型定价。 -// 先精确匹配,再通配符匹配(前缀越长优先级越高)。 +// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。 func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { // 精确匹配优先 for i := range pricingList { @@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } } } - // 通配符匹配:收集所有匹配项,按前缀长度降序取最长 - var matches []wildcardMatch + // 通配符匹配:按配置顺序,先匹配先使用 for i := range pricingList { p := &pricingList[i] if !isPlatformMatch(platform, p.Platform) { @@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } prefix := strings.TrimSuffix(ml, "*") if strings.HasPrefix(modelLower, prefix) { - matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p}) + return p } } } - if len(matches) == 0 { - return nil - } - sort.Slice(matches, func(i, j int) bool { - return matches[i].prefixLen > matches[j].prefixLen - }) - return matches[0].pricing + return nil } // isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 2f625393..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "wildcard matches by longest prefix (most specific wins)", + name: "wildcard matches by config order (first match wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" + wantID: 10, // config order: "claude-*" is first and matches, so it wins }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 47008df0..0f7cb99a 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. @@ -47,7 +48,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, - SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, + AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) if err != nil { @@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } + allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) + // Cascade: turning off refund_enabled also disables allow_user_refund + if !*req.RefundEnabled { + u.SetAllowUserRefund(false) + } + } + if req.AllowUserRefund != nil { + // Only allow enabling when refund_enabled is true + if *req.AllowUserRefund { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil && inst.RefundEnabled { + u.SetAllowUserRefund(true) + } + } else { + u.SetAllowUserRefund(false) + } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.RefundEnabledEQ(true), + paymentproviderinstance.AllowUserRefundEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { return nil, err diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..cce31f4d 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,28 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` + AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 68f9c697..75d75b2f 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -17,6 +17,19 @@ import ( // --- Refund Flow --- +// getOrderProviderInstance looks up the provider instance that processed this order. +// Returns nil, nil for legacy orders without provider_instance_id. +func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + return nil, nil + } + instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + if err != nil { + return nil, nil + } + return s.entClient.PaymentProviderInstance.Get(ctx, instID) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { @@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int if o.Status != OrderStatusCompleted { return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } + // Check provider instance allows user refund + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil || inst == nil { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order") + } + if !inst.AllowUserRefund { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider") + } return o, nil } @@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float if !psSliceContains(ok, o.Status) { return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } + // Check provider instance allows admin refund + inst, instErr := s.getOrderProviderInstance(ctx, o) + if instErr != nil { + slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr) + } + if inst != nil && !inst.RefundEnabled { + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") + } + if inst == nil && instErr == nil { + // Legacy order without provider_instance_id — block refund + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order") + } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } @@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult { if o.OrderType == payment.OrderTypeSubscription { p.DeductionType = payment.DeductionTypeSubscription + if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil { + p.SubDaysToDeduct = *o.SubscriptionDays + sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID) + if err == nil && sub != nil { + p.SubscriptionID = sub.ID + } else if !force { + return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true} + } + } return nil } u, err := s.userRepo.GetByID(ctx, o.UserID) @@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref p.BalanceToDeduct = 0 } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { + _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) + if err != nil { + slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) + if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + } + } + } else { + slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID) + p.SubDaysToDeduct = 0 + } + } if err := s.gwRefund(ctx, p); err != nil { return s.handleGwFail(ctx, p, err) } @@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr return false } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil { + slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err) + s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct}) + return false + } + } return true } diff --git a/backend/migrations/103_add_allow_user_refund.sql b/backend/migrations/103_add_allow_user_refund.sql new file mode 100644 index 00000000..79525382 --- /dev/null +++ b/backend/migrations/103_add_allow_user_refund.sql @@ -0,0 +1 @@ +ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS allow_user_refund BOOLEAN NOT NULL DEFAULT false; diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 1389b60f..5cedb107 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -75,5 +75,10 @@ export const paymentAPI = { /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) + }, + + /** Get provider instance IDs that allow user refund */ + getRefundEligibleProviders() { + return apiClient.get<{ provider_instance_ids: string[] }>('/payment/orders/refund-eligible-providers') } } diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue index 9b60cba1..10c1bfea 100644 --- a/frontend/src/components/payment/PaymentProviderDialog.vue +++ b/frontend/src/components/payment/PaymentProviderDialog.vue @@ -32,7 +32,8 @@
- + +
{{ t('admin.settings.payment.paymentMode') }}
@@ -243,6 +244,7 @@ const emit = defineEmits<{ enabled: boolean payment_mode: string refund_enabled: boolean + allow_user_refund: boolean config: Record limits: string }] @@ -258,6 +260,7 @@ const form = reactive({ enabled: true, payment_mode: PAYMENT_MODE_QRCODE, refund_enabled: false, + allow_user_refund: false, }) const config = reactive>({}) const limits = reactive>>({}) @@ -433,6 +436,7 @@ function handleSave() { enabled: form.enabled, payment_mode: form.provider_key === 'easypay' ? form.payment_mode : '', refund_enabled: form.refund_enabled, + allow_user_refund: form.refund_enabled ? form.allow_user_refund : false, config: filteredConfig, limits: serializeLimits(), }) @@ -452,6 +456,7 @@ function reset(defaultKey: string) { form.enabled = true form.payment_mode = defaultKey === 'easypay' ? PAYMENT_MODE_QRCODE : '' form.refund_enabled = false + form.allow_user_refund = false clearConfig() applyDefaults() } @@ -463,6 +468,7 @@ function loadProvider(provider: ProviderInstance) { form.enabled = provider.enabled form.payment_mode = provider.payment_mode || (provider.provider_key === 'easypay' ? PAYMENT_MODE_QRCODE : '') form.refund_enabled = provider.refund_enabled + form.allow_user_refund = provider.allow_user_refund clearConfig() // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••) if (provider.config) { diff --git a/frontend/src/components/payment/PaymentProviderList.vue b/frontend/src/components/payment/PaymentProviderList.vue index e942b8c4..49ebc726 100644 --- a/frontend/src/components/payment/PaymentProviderList.vue +++ b/frontend/src/components/payment/PaymentProviderList.vue @@ -115,7 +115,7 @@ const emit = defineEmits<{ create: [] edit: [provider: ProviderInstance] delete: [provider: ProviderInstance] - toggleField: [provider: ProviderInstance, field: 'enabled' | 'refund_enabled'] + toggleField: [provider: ProviderInstance, field: 'enabled' | 'refund_enabled' | 'allow_user_refund'] toggleType: [provider: ProviderInstance, type: string] reorder: [providers: { id: number; sort_order: number }[]] }>() diff --git a/frontend/src/components/payment/ProviderCard.vue b/frontend/src/components/payment/ProviderCard.vue index 9fc3b0ff..aecc8c8a 100644 --- a/frontend/src/components/payment/ProviderCard.vue +++ b/frontend/src/components/payment/ProviderCard.vue @@ -46,6 +46,7 @@
+
@@ -102,7 +105,7 @@ const { t } = useI18n() const appStore = useAppStore() const saving = ref(false) -const planForm = reactive({ name: '', group_id: null as number | null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', for_sale: true }) +const planForm = reactive({ name: '', group_id: null as number | null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', sort_order: 0, for_sale: true }) const planFeaturesText = ref('') const validityUnitOptions = computed(() => [ @@ -130,10 +133,10 @@ const selectedGroupInfo = computed(() => { watch(() => props.show, (visible) => { if (!visible) return if (props.plan) { - Object.assign(planForm, { name: props.plan.name, group_id: props.plan.group_id, description: props.plan.description, price: props.plan.price, original_price: props.plan.original_price || 0, validity_days: props.plan.validity_days, validity_unit: props.plan.validity_unit || 'days', for_sale: props.plan.for_sale }) + Object.assign(planForm, { name: props.plan.name, group_id: props.plan.group_id, description: props.plan.description, price: props.plan.price, original_price: props.plan.original_price || 0, validity_days: props.plan.validity_days, validity_unit: props.plan.validity_unit || 'days', sort_order: props.plan.sort_order || 0, for_sale: props.plan.for_sale }) planFeaturesText.value = (props.plan.features || []).join('\n') } else { - Object.assign(planForm, { name: '', group_id: null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', for_sale: true }) + Object.assign(planForm, { name: '', group_id: null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', sort_order: 0, for_sale: true }) planFeaturesText.value = '' } }) @@ -149,6 +152,7 @@ function buildPlanPayload() { original_price: planForm.original_price || 0, validity_days: planForm.validity_days, validity_unit: planForm.validity_unit, + sort_order: planForm.sort_order, for_sale: planForm.for_sale, features, } diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 3c7df572..bc16918c 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -102,10 +102,12 @@ interface ReturnInfo { } const returnInfo = ref(null) +const SUCCESS_STATUSES = new Set(['COMPLETED', 'PAID', 'RECHARGING']) + const isSuccess = computed(() => { // Always prioritize actual order status from backend if (order.value) { - return order.value.status === 'COMPLETED' || order.value.status === 'PAID' + return SUCCESS_STATUSES.has(order.value.status) } // Fallback only when order not loaded if (route.query.status === 'success') return true diff --git a/frontend/src/views/user/UserOrdersView.vue b/frontend/src/views/user/UserOrdersView.vue index 51aacf7d..ea888eb7 100644 --- a/frontend/src/views/user/UserOrdersView.vue +++ b/frontend/src/views/user/UserOrdersView.vue @@ -22,7 +22,7 @@ {{ t('payment.orders.cancel') }} - @@ -102,6 +102,7 @@ const appStore = useAppStore() const loading = ref(false) const actionLoading = ref(false) const orders = ref([]) +const refundEligibleProviders = ref>(new Set()) const currentFilter = ref('') const cancelTargetId = ref(null) const refundTarget = ref(null) @@ -171,5 +172,18 @@ async function confirmRefund() { } } -onMounted(() => fetchOrders()) +function canRequestRefund(order: PaymentOrder): boolean { + if (order.status !== 'COMPLETED') return false + if (!order.provider_instance_id) return false + return refundEligibleProviders.value.has(order.provider_instance_id) +} + +async function loadRefundEligibility() { + try { + const res = await paymentAPI.getRefundEligibleProviders() + refundEligibleProviders.value = new Set(res.data.provider_instance_ids || []) + } catch { /* ignore — default to hiding refund button */ } +} + +onMounted(() => { fetchOrders(); loadRefundEligibility() }) From 58677dd53fc8a0aa56ead0feeb87ac05b8f58018 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 18:34:57 +0800 Subject: [PATCH 093/122] fix: merge 5 PR-related improvements - gateway_handler: pass ParsedRequest to RecordUsage + set in gin.Context - channel_handler: add FeaturesConfig to CRUD (WebSearch channel toggle) - channel_repo: features_config JSONB persistence (Create/Get/Update/List) - security_headers: add Stripe CSP domains (script-src + frame-src) --- .../internal/handler/admin/channel_handler.go | 6 ++ backend/internal/handler/gateway_handler.go | 3 + backend/internal/repository/channel_repo.go | 61 ++++++++++++++----- .../server/middleware/security_headers.go | 13 +++- 4 files changed, 67 insertions(+), 16 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 88d27c47..9151d018 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -35,6 +35,7 @@ type createChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -49,6 +50,7 @@ type updateChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels *bool `json:"restrict_models"` Features *string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -93,6 +95,7 @@ type channelResponse struct { BillingModelSource string `json:"billing_model_source"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` GroupIDs []int64 `json:"group_ids"` ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelMapping map[string]map[string]string `json:"model_mapping"` @@ -148,6 +151,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Status: ch.Status, RestrictModels: ch.RestrictModels, Features: ch.Features, + FeaturesConfig: ch.FeaturesConfig, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -379,6 +383,7 @@ func (h *ChannelHandler) Create(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, AccountStatsPricingRules: statsRules, }) @@ -414,6 +419,7 @@ func (h *ChannelHandler) Update(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, } if req.ModelPricing != nil { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8ec54420..30065463 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -473,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: apiKey, User: apiKey.User, Account: account, @@ -675,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { @@ -813,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 583ce895..2cb90aab 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -80,11 +84,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} - var modelMappingJSON []byte + var modelMappingJSON, featuresConfigJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -92,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha return nil, fmt.Errorf("get channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -120,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW() - WHERE id = $9`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW() + WHERE id = $10`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -207,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) @@ -223,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -298,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -309,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -488,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string { return m } +func marshalFeaturesConfig(m map[string]any) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal features_config: %w", err) + } + return data, nil +} + +func unmarshalFeaturesConfig(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + // GetGroupPlatforms 批量查询分组 ID 对应的平台 func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { if len(groupIDs) == 0 { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 73210bfc..7021ab2e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -18,6 +18,8 @@ const ( NonceTemplate = "__CSP_NONCE__" // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics CloudflareInsightsDomain = "https://static.cloudflareinsights.com" + // StripeDomain is the domain for Stripe.js SDK + StripeDomain = "https://*.stripe.com" ) // GenerateNonce generates a cryptographically secure random nonce. @@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool { strings.HasPrefix(path, "/responses") } -// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. -// This allows the application to work correctly even if the config file has an older CSP policy. +// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, +// and Stripe.js domains. This allows the application to work correctly even if the +// config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { // Add nonce placeholder to script-src if not present if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") { @@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string { policy = addToDirective(policy, "script-src", CloudflareInsightsDomain) } + // Add Stripe.js domain to script-src and frame-src if not present + if !strings.Contains(policy, "stripe.com") { + policy = addToDirective(policy, "script-src", StripeDomain) + policy = addToDirective(policy, "frame-src", StripeDomain) + } + return policy } From c14d739360de12741920c01ca4bbd140559a90d1 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 18:41:09 +0800 Subject: [PATCH 094/122] fix: resolve 3 code review issues in allow_user_refund 1. PrepareRefund: block refund on provider instance lookup failure instead of silently skipping permission check (medium severity) 2. UpdateProviderInstance: allow enabling refund_enabled and allow_user_refund in the same request by checking req.RefundEnabled value before falling back to DB read 3. ExecuteRefund: only revoke subscription on ErrAdjustWouldExpire, abort on other errors (DB failure, not found) instead of unconditionally revoking --- .../service/payment_config_providers.go | 14 ++++++++--- backend/internal/service/payment_refund.go | 25 +++++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 0f7cb99a..3c406b45 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -231,10 +231,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } } if req.AllowUserRefund != nil { - // Only allow enabling when refund_enabled is true + // Only allow enabling when refund_enabled is (or will be) true if *req.AllowUserRefund { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil && inst.RefundEnabled { + refundEnabled := false + if req.RefundEnabled != nil { + refundEnabled = *req.RefundEnabled + } else { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil { + refundEnabled = inst.RefundEnabled + } + } + if refundEnabled { u.SetAllowUserRefund(true) } } else { diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 75d75b2f..99468433 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log/slog" "math" @@ -93,15 +94,16 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float // Check provider instance allows admin refund inst, instErr := s.getOrderProviderInstance(ctx, o) if instErr != nil { - slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr) + slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr) + return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order") } - if inst != nil && !inst.RefundEnabled { - return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") - } - if inst == nil && instErr == nil { + if inst == nil { // Legacy order without provider_instance_id — block refund return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order") } + if !inst.RefundEnabled { + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") + } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } @@ -183,10 +185,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) if err != nil { - slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) - if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + if errors.Is(err, ErrAdjustWouldExpire) { + // Deduction would expire the subscription — revoke it entirely + slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) + if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + } + } else { + // Other errors (DB failure, not found) — abort refund s.restoreStatus(ctx, p) - return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + return nil, fmt.Errorf("deduct subscription days: %w", err) } } } else { From 63f539b3828eccbe2bf416458447acb89c6df8bd Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 19:29:37 +0800 Subject: [PATCH 095/122] fix: merge general improvements from release branch Backend: - gateway_handler: pass subject.UserID instead of int64(0) for user-level routing - setting_handler: add missing BalanceLowNotifyRechargeURL to UpdateSettings response - openai_gateway_service: use applyAccountStatsCost for account stats pricing integration - embed_on: add local file override (data/public/) for embedded frontend assets Frontend: - useTableSelection: add batchUpdate method for batch operations - AccountsView: virtual scrolling params, Set-based isSelected, swipe virtualization - ProxiesView: add batchUpdate to selection and swipe-select - BulkEditAccountModal: fix submit handler to prevent event object as argument - SettingsView: move payload construction outside try block - i18n: add general translation keys (saved, deleted, view, validation, allowUserRefund) - api/client: reorder error fields for consistency - stores/payment: clarify pollOrderStatus JSDoc --- .../internal/handler/admin/setting_handler.go | 1 + backend/internal/handler/gateway_handler.go | 2 +- .../service/openai_gateway_service.go | 11 +--- backend/internal/web/embed_on.go | 65 ++++++++++++++++--- frontend/src/api/client.ts | 2 +- .../account/BulkEditAccountModal.vue | 2 +- frontend/src/composables/useTableSelection.ts | 9 ++- frontend/src/i18n/locales/en.ts | 7 ++ frontend/src/i18n/locales/zh.ts | 6 ++ frontend/src/stores/payment.ts | 2 +- frontend/src/views/admin/AccountsView.vue | 26 ++++++-- frontend/src/views/admin/ProxiesView.vue | 6 +- frontend/src/views/admin/SettingsView.vue | 11 ++-- 13 files changed, 114 insertions(+), 36 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 9b49150c..29c97b4b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1071,6 +1071,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableCCHSigning: updatedSettings.EnableCCHSigning, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), PaymentEnabled: updatedPaymentCfg.Enabled, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 30065463..f5eff8c9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -522,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 9a6fbb8f..6087b7b6 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4575,14 +4575,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { - statsModel := result.UpstreamModel - if statsModel == "" { - statsModel = result.Model - } - usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, statsModel, - tokens, 1, cost.TotalCost, + applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model, + tokens, cost.TotalCost, ) } diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index ad5ac7d8..89d09eef 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -10,6 +10,8 @@ import ( "io" "io/fs" "net/http" + "os" + "path/filepath" "strings" "time" @@ -32,11 +34,12 @@ type PublicSettingsProvider interface { // FrontendServer serves the embedded frontend with settings injection type FrontendServer struct { - distFS fs.FS - fileServer http.Handler - baseHTML []byte - cache *HTMLCache - settings PublicSettingsProvider + distFS fs.FS + fileServer http.Handler + baseHTML []byte + cache *HTMLCache + settings PublicSettingsProvider + overrideDir string // local file override directory } // NewFrontendServer creates a new frontend server with settings injection @@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer cache.SetBaseHTML(baseHTML) return &FrontendServer{ - distFS: distFS, - fileServer: http.FileServer(http.FS(distFS)), - baseHTML: baseHTML, - cache: cache, - settings: settingsProvider, + distFS: distFS, + fileServer: http.FileServer(http.FS(distFS)), + baseHTML: baseHTML, + cache: cache, + settings: settingsProvider, + overrideDir: filepath.Join("data", "public"), }, nil } @@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { return } + // Try local override first + if s.tryServeOverride(c, cleanPath) { + return + } + // Serve static files normally s.fileServer.ServeHTTP(c.Writer, c.Request) c.Abort() @@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool { return true } +// tryServeOverride checks if a local override file exists and serves it. +// Files in overrideDir take precedence over embedded files. +func (s *FrontendServer) tryServeOverride(c *gin.Context, cleanPath string) bool { + if s.overrideDir == "" { + return false + } + filePath := filepath.Join(s.overrideDir, filepath.Clean("/"+cleanPath)) + info, err := os.Stat(filePath) + if err != nil || info.IsDir() { + return false + } + c.File(filePath) + c.Abort() + return true +} + func (s *FrontendServer) serveIndexHTML(c *gin.Context) { // Get nonce from context (generated by SecurityHeaders middleware) nonce := middleware.GetNonceFromContext(c) @@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { panic("failed to get dist subdirectory: " + err.Error()) } fileServer := http.FileServer(http.FS(distFS)) + overrideDir := filepath.Join("data", "public") return func(c *gin.Context) { path := c.Request.URL.Path @@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if file, err := distFS.Open(cleanPath); err == nil { _ = file.Close() + // Try local override first + if tryServeOverrideFile(c, overrideDir, cleanPath) { + return + } fileServer.ServeHTTP(c.Writer, c.Request) c.Abort() return @@ -251,6 +281,21 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { } } +// tryServeOverrideFile is a standalone version of tryServeOverride for legacy usage. +func tryServeOverrideFile(c *gin.Context, overrideDir, cleanPath string) bool { + if overrideDir == "" { + return false + } + filePath := filepath.Join(overrideDir, filepath.Clean("/"+cleanPath)) + info, err := os.Stat(filePath) + if err != nil || info.IsDir() { + return false + } + c.File(filePath) + c.Abort() + return true +} + func shouldBypassEmbeddedFrontend(path string) bool { trimmed := strings.TrimSpace(path) return strings.HasPrefix(trimmed, "/api/") || diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 2908c6b1..8a586902 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -270,9 +270,9 @@ apiClient.interceptors.response.use( return Promise.reject({ status, code: apiData.code, + reason: apiData.reason, error: apiData.error, message: apiData.message || apiData.detail || error.message, - reason: apiData.reason, metadata: apiData.metadata, }) } diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 2934fbd9..5461015b 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -5,7 +5,7 @@ width="wide" @close="handleClose" > -
+

diff --git a/frontend/src/composables/useTableSelection.ts b/frontend/src/composables/useTableSelection.ts index a65144a9..f0e096ff 100644 --- a/frontend/src/composables/useTableSelection.ts +++ b/frontend/src/composables/useTableSelection.ts @@ -76,6 +76,12 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions) => void) => { + const draft = new Set(selectedSet.value) + updater(draft) + replaceSelectedSet(draft) + } + const selectVisible = () => { toggleVisible(true) } @@ -93,6 +99,7 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions { return response.data } - /** Poll order status by ID */ + /** Poll order status by ID (read-only, no upstream check) */ async function pollOrderStatus(orderId: number): Promise { try { const response = await paymentAPI.getOrder(orderId) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index d7fae112..4fec956b 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -144,6 +144,7 @@