From f68909a68ba31008f1d30f9f42069b6326f36525 Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Tue, 21 Apr 2026 08:54:18 +0800
Subject: [PATCH 01/33] fix: reconcile openai admin test rate-limit state
---
.../internal/service/account_test_service.go | 36 +++++++++++++++++++
.../account_test_service_openai_test.go | 19 ++++++----
backend/internal/service/ratelimit_service.go | 6 +++-
3 files changed, 54 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index e5bc93ca..bb2fb8c0 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.reconcileOpenAI429State(ctx, account, resp.Header, body)
+ }
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
@@ -550,6 +553,39 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
+func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
+ if s == nil || s.accountRepo == nil || account == nil {
+ return
+ }
+
+ var resetAt *time.Time
+ if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
+ resetAt = calculated
+ } else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
+ t := time.Unix(*unixTs, 0)
+ resetAt = &t
+ }
+ if resetAt == nil {
+ return
+ }
+
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
+ return
+ }
+
+ now := time.Now()
+ account.RateLimitedAt = &now
+ account.RateLimitResetAt = resetAt
+
+ if account.Status == StatusError {
+ if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
+ return
+ }
+ account.Status = StatusActive
+ account.ErrorMessage = ""
+ }
+}
+
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 82ff0a8b..213ef52c 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -61,9 +61,10 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
- updatedExtra map[string]any
+ updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
+ clearedErrorID int64
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -77,6 +78,11 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
+func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
+ r.clearedErrorID = id
+ return nil
+}
+
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
@@ -111,11 +117,11 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
-func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
+func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
- resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
@@ -130,6 +136,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
+ Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
@@ -138,7 +145,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
- require.Zero(t, repo.rateLimitedID)
- require.Nil(t, repo.rateLimitedAt)
- require.Nil(t, account.RateLimitResetAt)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
}
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 4730303f..9344de47 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
-func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
+func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
@@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
+func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
+ return calculateOpenAI429ResetTime(headers)
+}
+
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
From 5fc30ea964c47bac64b4c3dbde5e4c4180eb5078 Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Tue, 21 Apr 2026 09:03:25 +0800
Subject: [PATCH 02/33] test: cover openai admin test state transitions
---
.../account_test_service_openai_test.go | 130 +++++++++++++++++-
1 file changed, 127 insertions(+), 3 deletions(-)
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 213ef52c..12d8128a 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -61,10 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
- updatedExtra map[string]any
- rateLimitedID int64
- rateLimitedAt *time.Time
+ updatedExtra map[string]any
+ rateLimitedID int64
+ rateLimitedAt *time.Time
clearedErrorID int64
+ setErrorID int64
+ setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -83,6 +85,12 @@ func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
return nil
}
+func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
+ r.setErrorID = id
+ r.setErrorMsg = errorMsg
+ return nil
+}
+
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
@@ -148,4 +156,120 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testin
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
+ require.NotNil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 77,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
+ require.NotNil(t, account.RateLimitResetAt)
+ require.Empty(t, repo.updatedExtra)
+}
+
+func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 78,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.NotNil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 79,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "stale 403",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ require.Error(t, err)
+ require.Zero(t, repo.rateLimitedID)
+ require.Nil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusError, account.Status)
+ require.Equal(t, "stale 403", account.ErrorMessage)
+ require.Nil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 80,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.setErrorID)
+ require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
+ require.Zero(t, repo.rateLimitedID)
+ require.Zero(t, repo.clearedErrorID)
+ require.Nil(t, account.RateLimitResetAt)
}
From d80469ea35d6dc085e9e4aee43834280c3a18b5d Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Thu, 23 Apr 2026 18:15:00 +0800
Subject: [PATCH 03/33] test: fix OpenAI account test helper calls after rebase
---
.../internal/service/account_test_service_openai_test.go | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 12d8128a..c1e42b4f 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -180,7 +180,7 @@ func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleErro
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
@@ -209,7 +209,7 @@ func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
@@ -237,7 +237,7 @@ func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
@@ -265,7 +265,7 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "")
require.Error(t, err)
require.Equal(t, account.ID, repo.setErrorID)
require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
From f3ea878ba21297f9237142d20c6a83f698cbabbb Mon Sep 17 00:00:00 2001
From: KnowSky404
Date: Thu, 23 Apr 2026 18:33:27 +0800
Subject: [PATCH 04/33] chore: trigger PR checks
From c4d496da18db6885b7d22d449451b1dbf518b73d Mon Sep 17 00:00:00 2001
From: gaoren002
Date: Fri, 24 Apr 2026 07:42:31 +0000
Subject: [PATCH 05/33] fix(openai): handle codex spark model limitations
---
.../service/openai_codex_transform.go | 66 +++++++++++
.../service/openai_codex_transform_test.go | 111 ++++++++++++++++++
.../service/openai_gateway_service.go | 11 ++
.../internal/service/openai_model_mapping.go | 24 +++-
.../service/openai_model_mapping_test.go | 47 +++++++-
5 files changed, 257 insertions(+), 2 deletions(-)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 14abde9b..f9c0de72 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -48,6 +48,8 @@ type codexTransformResult struct {
const (
codexImageGenerationBridgeMarker = ""
codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n"
+ codexSparkImageUnsupportedMarker = ""
+ codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n"
)
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
@@ -165,6 +167,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if applyInstructions(reqBody, isCodexCLI) {
result.Modified = true
}
+ if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
+ result.Modified = true
+ }
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
@@ -244,6 +249,10 @@ func normalizeCodexModel(model string) string {
return "gpt-5.4"
}
+func isCodexSparkModel(model string) bool {
+ return normalizeCodexModel(model) == "gpt-5.3-codex-spark"
+}
+
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -265,6 +274,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
return false
}
+func hasOpenAIInputImage(reqBody map[string]any) bool {
+ if reqBody == nil {
+ return false
+ }
+ return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"])
+}
+
+func hasOpenAIInputImageValue(value any) bool {
+ switch v := value.(type) {
+ case []any:
+ for _, item := range v {
+ if hasOpenAIInputImageValue(item) {
+ return true
+ }
+ }
+ case map[string]any:
+ if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" {
+ return true
+ }
+ if _, ok := v["image_url"]; ok {
+ return true
+ }
+ return hasOpenAIInputImageValue(v["content"])
+ }
+ return false
+}
+
+func validateCodexSparkInput(reqBody map[string]any, model string) error {
+ if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) {
+ return nil
+ }
+ return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
+}
+
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
@@ -309,6 +352,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
+ if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
+ return false
+ }
tool := map[string]any{
"type": "image_generation",
@@ -344,6 +390,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
return false
}
+ if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
+ return false
+ }
existing, _ := reqBody["instructions"].(string)
if strings.Contains(existing, codexImageGenerationBridgeMarker) {
@@ -360,6 +409,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
return true
}
+func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
+ if len(reqBody) == 0 {
+ return false
+ }
+ existing, _ := reqBody["instructions"].(string)
+ if strings.Contains(existing, codexSparkImageUnsupportedMarker) {
+ return false
+ }
+ existing = strings.TrimRight(existing, " \t\r\n")
+ if strings.TrimSpace(existing) == "" {
+ reqBody["instructions"] = codexSparkImageUnsupportedText
+ return true
+ }
+ reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText
+ return true
+}
+
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
if !hasOpenAIImageGenerationTool(reqBody) {
return nil
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 4fd16fdb..3a965d5c 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -261,6 +261,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
require.Equal(t, "png", tool["output_format"])
}
+func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": "draw a cat",
+ }
+
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
+ require.False(t, modified)
+ require.NotContains(t, reqBody, "tools")
+}
+
func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
@@ -306,6 +317,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *
func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
reqBody := map[string]any{
+ "model": "gpt-5.4",
"instructions": "existing instructions",
"tools": []any{
map[string]any{"type": "image_generation", "output_format": "png"},
@@ -325,6 +337,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin
require.False(t, modified)
}
+func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "instructions": "existing instructions",
+ "tools": []any{
+ map[string]any{"type": "image_generation", "output_format": "png"},
+ },
+ }
+
+ modified := applyCodexImageGenerationBridgeInstructions(reqBody)
+ require.False(t, modified)
+ require.Equal(t, "existing instructions", reqBody["instructions"])
+}
+
func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
reqBody := map[string]any{
"instructions": "existing instructions",
@@ -338,6 +364,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te
require.Equal(t, "existing instructions", reqBody["instructions"])
}
+func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": "describe"},
+ map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
+ },
+ },
+ },
+ }
+
+ err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "does not support image input")
+}
+
+func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "messages": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "text", "text": "describe"},
+ map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
+ },
+ },
+ },
+ }
+
+ err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
+ require.Error(t, err)
+}
+
+func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": "hello"},
+ },
+ },
+ },
+ }
+
+ require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
+}
+
+func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "instructions": "existing instructions",
+ "input": "hello",
+ }
+
+ result := applyCodexOAuthTransform(reqBody, true, false)
+ require.True(t, result.Modified)
+
+ instructions, ok := reqBody["instructions"].(string)
+ require.True(t, ok)
+ require.Contains(t, instructions, "existing instructions")
+ require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
+ require.Contains(t, instructions, "does not support image generation")
+ require.Contains(t, instructions, "switch to a non-Spark Codex model")
+ require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
+}
+
+func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "instructions": "existing instructions",
+ "input": "hello",
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+ instructions, ok := reqBody["instructions"].(string)
+ require.True(t, ok)
+ require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
+}
+
func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-image-2",
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index d99cd7da..2d05c3ea 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -1995,6 +1995,17 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
account.Type,
)
}
+ if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil {
+ setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": err.Error(),
+ "param": "input",
+ },
+ })
+ return nil, err
+ }
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go
index 9bf3fba3..993c0b13 100644
--- a/backend/internal/service/openai_model_mapping.go
+++ b/backend/internal/service/openai_model_mapping.go
@@ -1,5 +1,7 @@
package service
+import "strings"
+
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
@@ -12,8 +14,28 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
- if !matched && defaultMappedModel != "" {
+ if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
return defaultMappedModel
}
return mappedModel
}
+
+func isExplicitCodexModel(model string) bool {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return false
+ }
+ if strings.Contains(model, "/") {
+ parts := strings.Split(model, "/")
+ model = parts[len(parts)-1]
+ }
+ model = strings.ToLower(strings.TrimSpace(model))
+ if getNormalizedCodexModel(model) != "" {
+ return true
+ }
+ if strings.HasSuffix(model, "-openai-compact") {
+ base := strings.TrimSuffix(model, "-openai-compact")
+ return getNormalizedCodexModel(base) != ""
+ }
+ return false
+}
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index f25863a8..21a2e9a0 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
account: &Account{
Credentials: map[string]any{},
},
- requestedModel: "gpt-5.4",
+ requestedModel: "claude-opus-4-6",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
+ {
+ name: "preserves explicit gpt-5.4 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.4",
+ defaultMappedModel: "gpt-4o-mini",
+ expectedModel: "gpt-5.4",
+ },
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
+ {
+ name: "preserves codex spark instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.3-codex-spark",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.3-codex-spark",
+ },
+ {
+ name: "preserves gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.5",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.5",
+ },
+ {
+ name: "preserves openai namespaced gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "openai/gpt-5.5",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "openai/gpt-5.5",
+ },
+ {
+ name: "preserves compact gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.5-openai-compact",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.5-openai-compact",
+ },
}
for _, tt := range tests {
From 959af1c8f6fb120863242d2be5b62b5820b8af0e Mon Sep 17 00:00:00 2001
From: song
Date: Fri, 24 Apr 2026 17:15:42 +0800
Subject: [PATCH 06/33] fix(openai): preserve codex tool call ids
---
.../service/openai_codex_transform.go | 22 ++++--
.../service/openai_codex_transform_test.go | 72 +++++++++++++++++++
.../service/openai_tool_continuation.go | 4 +-
.../service/openai_tool_continuation_test.go | 3 +
4 files changed, 94 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 14abde9b..65f7f5b4 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -658,12 +658,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
}
+ if !isCodexToolCallItemType(typ) {
+ ensureCopy()
+ delete(newItem, "call_id")
+ }
+
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
- if !isCodexToolCallItemType(typ) {
- delete(newItem, "call_id")
- }
}
filtered = append(filtered, newItem)
@@ -672,10 +674,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
func isCodexToolCallItemType(typ string) bool {
- if typ == "" {
+ switch typ {
+ case "function_call",
+ "tool_call",
+ "local_shell_call",
+ "tool_search_call",
+ "custom_tool_call",
+ "function_call_output",
+ "mcp_tool_call_output",
+ "custom_tool_call_output",
+ "tool_search_output":
+ return true
+ default:
return false
}
- return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
}
func normalizeCodexTools(reqBody map[string]any) bool {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 4fd16fdb..476f1ea9 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -92,6 +92,78 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
require.Equal(t, "fc1", second["call_id"])
}
+func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "tool_search_output", first["type"])
+ require.Equal(t, "fc1", first["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
+ map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "fccustom", first["call_id"])
+
+ second, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "fcmcp", second["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
+ map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
+ },
+ "tool_choice": "auto",
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "ig_123", first["id"])
+ _, hasCallID := first["call_id"]
+ require.False(t, hasCallID)
+
+ second, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ _, hasCallID = second["call_id"]
+ require.False(t, hasCallID)
+}
+
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go
index dea3c172..c0f98de4 100644
--- a/backend/internal/service/openai_tool_continuation.go
+++ b/backend/internal/service/openai_tool_continuation.go
@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
-// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
+// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
continue
}
itemType, _ := itemMap["type"].(string)
- if itemType == "function_call_output" || itemType == "item_reference" {
+ if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
return true
}
}
diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go
index fe737ad6..3f415d9d 100644
--- a/backend/internal/service/openai_tool_continuation_test.go
+++ b/backend/internal/service/openai_tool_continuation_test.go
@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
+ {name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
+ {name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
+ {name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
From e65574dea99d42da4681a52f10e7f9bdc5deed99 Mon Sep 17 00:00:00 2001
From: gaoren002
Date: Fri, 24 Apr 2026 12:03:19 +0000
Subject: [PATCH 07/33] fix(openai): normalize codex responses payloads
---
.../service/openai_codex_transform.go | 214 ++++++++++++++++++
.../service/openai_codex_transform_test.go | 125 ++++++++++
2 files changed, 339 insertions(+)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 20d303b3..4903c420 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -1,6 +1,7 @@
package service
import (
+ "encoding/json"
"fmt"
"strings"
)
@@ -153,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if normalizeCodexTools(reqBody) {
result.Modified = true
}
+ if normalizeCodexToolChoice(reqBody) {
+ result.Modified = true
+ }
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
@@ -173,6 +177,14 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
+ if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
+ input = normalizedInput
+ result.Modified = true
+ }
+ if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
+ input = normalizedInput
+ result.Modified = true
+ }
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
@@ -197,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
return result
}
+func normalizeCodexToolChoice(reqBody map[string]any) bool {
+ choice, ok := reqBody["tool_choice"]
+ if !ok || choice == nil {
+ return false
+ }
+ choiceMap, ok := choice.(map[string]any)
+ if !ok {
+ return false
+ }
+ choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
+ if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
+ return false
+ }
+ reqBody["tool_choice"] = "auto"
+ return true
+}
+
+func codexToolsContainType(rawTools any, toolType string) bool {
+ tools, ok := rawTools.([]any)
+ if !ok || strings.TrimSpace(toolType) == "" {
+ return false
+ }
+ for _, rawTool := range tools {
+ tool, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
+ if len(input) == 0 {
+ return input, false
+ }
+
+ modified := false
+ normalized := make([]any, 0, len(input))
+ for _, item := range input {
+ m, ok := item.(map[string]any)
+ if !ok {
+ normalized = append(normalized, item)
+ continue
+ }
+ role, _ := m["role"].(string)
+ if strings.TrimSpace(role) != "tool" {
+ normalized = append(normalized, item)
+ continue
+ }
+
+ callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"])
+ callID = strings.TrimSpace(callID)
+ if callID == "" {
+ // Responses does not accept role:"tool". If no call id is available,
+ // preserve the text as a user message instead of sending invalid input.
+ fallback := make(map[string]any, len(m))
+ for key, value := range m {
+ fallback[key] = value
+ }
+ fallback["role"] = "user"
+ delete(fallback, "tool_call_id")
+ normalized = append(normalized, fallback)
+ modified = true
+ continue
+ }
+
+ output := extractTextFromContent(m["content"])
+ if output == "" {
+ if value, ok := m["output"].(string); ok {
+ output = value
+ }
+ }
+ if output == "" && m["content"] != nil {
+ if b, err := json.Marshal(m["content"]); err == nil {
+ output = string(b)
+ }
+ }
+
+ normalized = append(normalized, map[string]any{
+ "type": "function_call_output",
+ "call_id": callID,
+ "output": output,
+ })
+ modified = true
+ }
+ if !modified {
+ return input, false
+ }
+ return normalized, true
+}
+
+func normalizeCodexMessageContentText(input []any) ([]any, bool) {
+ if len(input) == 0 {
+ return input, false
+ }
+
+ modified := false
+ normalized := make([]any, 0, len(input))
+ for _, item := range input {
+ m, ok := item.(map[string]any)
+ if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" {
+ normalized = append(normalized, item)
+ continue
+ }
+ parts, ok := m["content"].([]any)
+ if !ok {
+ normalized = append(normalized, item)
+ continue
+ }
+
+ var newItem map[string]any
+ var newParts []any
+ ensureItemCopy := func() {
+ if newItem != nil {
+ return
+ }
+ newItem = make(map[string]any, len(m))
+ for key, value := range m {
+ newItem[key] = value
+ }
+ newParts = make([]any, len(parts))
+ copy(newParts, parts)
+ }
+
+ for i, rawPart := range parts {
+ part, ok := rawPart.(map[string]any)
+ if !ok {
+ continue
+ }
+ text, hasText := part["text"]
+ if !hasText {
+ continue
+ }
+ if _, ok := text.(string); ok {
+ continue
+ }
+
+ ensureItemCopy()
+ newPart := make(map[string]any, len(part))
+ for key, value := range part {
+ newPart[key] = value
+ }
+ newPart["text"] = stringifyCodexContentText(text)
+ newParts[i] = newPart
+ modified = true
+ }
+
+ if newItem != nil {
+ newItem["content"] = newParts
+ normalized = append(normalized, newItem)
+ continue
+ }
+ normalized = append(normalized, item)
+ }
+ if !modified {
+ return input, false
+ }
+ return normalized, true
+}
+
+func stringifyCodexContentText(value any) string {
+ switch v := value.(type) {
+ case string:
+ return v
+ case nil:
+ return ""
+ default:
+ if b, err := json.Marshal(v); err == nil {
+ return string(b)
+ }
+ return fmt.Sprint(v)
+ }
+}
+
func normalizeCodexModel(model string) string {
model = strings.TrimSpace(model)
if model == "" {
@@ -729,6 +918,22 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
delete(newItem, "call_id")
}
+ if codexInputItemRequiresName(typ) {
+ if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
+ name := firstNonEmptyString(m["tool_name"])
+ if name == "" {
+ if function, ok := m["function"].(map[string]any); ok {
+ name = firstNonEmptyString(function["name"])
+ }
+ }
+ if name == "" {
+ name = "tool"
+ }
+ ensureCopy()
+ newItem["name"] = name
+ }
+ }
+
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
@@ -756,6 +961,15 @@ func isCodexToolCallItemType(typ string) bool {
}
}
+func codexInputItemRequiresName(typ string) bool {
+ switch strings.TrimSpace(typ) {
+ case "function_call", "custom_tool_call", "mcp_tool_call":
+ return true
+ default:
+ return false
+ }
+}
+
func normalizeCodexTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index f655e61c..ca9b4cea 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -164,6 +164,131 @@ func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testi
require.False(t, hasCallID)
}
+func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "message",
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "ok",
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function_call_output", item["type"])
+ require.Equal(t, "fc1", item["call_id"])
+ require.Equal(t, "ok", item["output"])
+ _, hasRole := item["role"]
+ require.False(t, hasRole)
+}
+
+func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": []any{"a", "b"}},
+ },
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ content, ok := item["content"].([]any)
+ require.True(t, ok)
+ part, ok := content[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, `["a","b"]`, part["text"])
+}
+
+func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{"type": "custom"},
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ require.Equal(t, "auto", reqBody["tool_choice"])
+}
+
+func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "custom", "name": "shell"},
+ },
+ "tool_choice": map[string]any{"type": "custom"},
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ choice, ok := reqBody["tool_choice"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "custom", choice["type"])
+}
+
+func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{"type": "message", "role": "user", "content": "run tool"},
+ map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+ item, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function_call", item["type"])
+ require.Equal(t, "tool", item["name"])
+ require.Equal(t, "fc1", item["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "shell", item["name"])
+ require.Equal(t, "fc1", item["call_id"])
+}
+
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
From 27ee141c1edef7682fb19e3cf14b9433a592998e Mon Sep 17 00:00:00 2001
From: gaoren002
Date: Fri, 24 Apr 2026 13:24:21 +0000
Subject: [PATCH 08/33] fix(openai): preserve mcp tool call ids
---
.../service/openai_codex_transform.go | 1 +
.../service/openai_codex_transform_test.go | 32 +++++++++++++++++++
2 files changed, 33 insertions(+)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 4903c420..e765d7e9 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -951,6 +951,7 @@ func isCodexToolCallItemType(typ string) bool {
"local_shell_call",
"tool_search_call",
"custom_tool_call",
+ "mcp_tool_call",
"function_call_output",
"mcp_tool_call_output",
"custom_tool_call_output",
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index ca9b4cea..75f5c55c 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -289,6 +289,38 @@ func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
require.Equal(t, "fc1", item["call_id"])
}
+func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "mcp_tool_call",
+ "call_id": "call_abc",
+ "name": "remote_tool",
+ "arguments": "{}",
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "mcp_tool_call", item["type"])
+ require.Equal(t, "remote_tool", item["name"])
+ require.Equal(t, "fcabc", item["call_id"])
+}
+
+func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
+ for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} {
+ require.True(t, codexInputItemRequiresName(typ), typ)
+ require.True(t, isCodexToolCallItemType(typ), typ)
+ }
+}
+
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
From f03de00cb9a8ead97cdf4507650a746eeb66a198 Mon Sep 17 00:00:00 2001
From: VpSanta33 <1398474964@qq.com>
Date: Fri, 24 Apr 2026 21:41:26 +0800
Subject: [PATCH 09/33] feat: add affiliate invite rebate flow and admin
rebate-rate setting
---
backend/cmd/server/wire_gen.go | 28 +-
.../internal/handler/admin/setting_handler.go | 17 +
backend/internal/handler/auth_handler.go | 11 +-
backend/internal/handler/dto/settings.go | 1 +
backend/internal/handler/user_handler.go | 74 ++-
backend/internal/handler/wire.go | 14 +-
backend/internal/repository/affiliate_repo.go | 420 ++++++++++++++++++
.../affiliate_repo_integration_test.go | 114 +++++
backend/internal/repository/wire.go | 1 +
backend/internal/server/api_contract_test.go | 2 +
backend/internal/server/routes/user.go | 2 +
backend/internal/service/affiliate_service.go | 288 ++++++++++++
.../service/affiliate_service_test.go | 59 +++
backend/internal/service/auth_service.go | 29 +-
backend/internal/service/domain_constants.go | 8 +
.../internal/service/payment_fulfillment.go | 136 ++++++
backend/internal/service/payment_service.go | 30 +-
backend/internal/service/setting_service.go | 22 +
backend/internal/service/settings_view.go | 1 +
backend/internal/service/wire.go | 52 ++-
.../migrations/130_add_user_affiliates.sql | 20 +
.../131_affiliate_rebate_hardening.sql | 58 +++
frontend/src/api/admin/settings.ts | 2 +
frontend/src/api/user.ts | 23 +-
frontend/src/components/layout/AppSidebar.vue | 1 +
frontend/src/i18n/locales/en.ts | 45 ++
frontend/src/i18n/locales/zh.ts | 44 ++
frontend/src/router/index.ts | 12 +
frontend/src/types/index.ts | 23 +
frontend/src/views/admin/SettingsView.vue | 30 ++
frontend/src/views/auth/EmailVerifyView.vue | 5 +-
frontend/src/views/auth/RegisterView.vue | 13 +-
frontend/src/views/user/AffiliateView.vue | 201 +++++++++
33 files changed, 1744 insertions(+), 42 deletions(-)
create mode 100644 backend/internal/repository/affiliate_repo.go
create mode 100644 backend/internal/repository/affiliate_repo_integration_test.go
create mode 100644 backend/internal/service/affiliate_service.go
create mode 100644 backend/internal/service/affiliate_service_test.go
create mode 100644 backend/migrations/130_add_user_affiliates.sql
create mode 100644 backend/migrations/131_affiliate_rebate_hardening.sql
create mode 100644 frontend/src/views/user/AffiliateView.vue
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 93270e7e..f8e0dcf4 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
- authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
+ affiliateRepository := repository.NewAffiliateRepository(client, db)
+ affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
+ authService := service.ProvideAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache)
+ userHandler := handler.ProvideUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
+ channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
+ channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
+ channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -192,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
- paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
+ paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
@@ -221,20 +226,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
- sqlDB, err := repository.ProvideSQLDB(client)
- if err != nil {
- return nil, err
- }
- channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB)
- channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB)
+ channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
+ channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
- channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
- channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
- channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
- channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
@@ -245,9 +241,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
+ availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -263,6 +260,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 4277f0f1..2d4dcb5b 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
+ AffiliateRebateRate: settings.AffiliateRebateRate,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
@@ -338,6 +339,7 @@ type UpdateSettingsRequest struct {
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
@@ -468,6 +470,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
+ affiliateRebateRate := previousSettings.AffiliateRebateRate
+ if req.AffiliateRebateRate != nil {
+ affiliateRebateRate = *req.AffiliateRebateRate
+ }
+ if affiliateRebateRate < service.AffiliateRebateRateMin {
+ affiliateRebateRate = service.AffiliateRebateRateMin
+ }
+ if affiliateRebateRate > service.AffiliateRebateRateMax {
+ affiliateRebateRate = service.AffiliateRebateRateMax
+ }
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -1119,6 +1131,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
+ AffiliateRebateRate: affiliateRebateRate,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
@@ -1433,6 +1446,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
+ AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -1738,6 +1752,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
+ if before.AffiliateRebateRate != after.AffiliateRebateRate {
+ changed = append(changed, "affiliate_rebate_rate")
+ }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index dc68a466..1f9a66ff 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -48,6 +48,7 @@ type RegisterRequest struct {
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
+ AffCode string `json:"aff_code"` // 邀请返利码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
- _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
+ _, user, err := h.authService.RegisterWithVerification(
+ c.Request.Context(),
+ req.Email,
+ req.Password,
+ req.VerifyCode,
+ req.PromoCode,
+ req.InvitationCode,
+ req.AffCode,
+ )
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 2affbc46..86074df7 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index f74c2b72..c386792c 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -5,6 +5,7 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -14,10 +15,11 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
- userService *service.UserService
- authService *service.AuthService
- emailService *service.EmailService
- emailCache service.EmailCache
+ userService *service.UserService
+ authService *service.AuthService
+ emailService *service.EmailService
+ emailCache service.EmailCache
+ affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
@@ -35,6 +37,13 @@ func NewUserHandler(
}
}
+func (h *UserHandler) SetAffiliateService(affiliateService *service.AffiliateService) {
+ if h == nil {
+ return
+ }
+ h.affiliateService = affiliateService
+}
+
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
@@ -159,6 +168,63 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, profileResp)
}
+func (h *UserHandler) affiliateServiceOrErr() (*service.AffiliateService, error) {
+ if h == nil || h.affiliateService == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return h.affiliateService, nil
+}
+
+// GetAffiliate returns the current user's affiliate details.
+// GET /api/v1/user/aff
+func (h *UserHandler) GetAffiliate(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ affiliateSvc, err := h.affiliateServiceOrErr()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ detail, err := affiliateSvc.GetAffiliateDetail(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, detail)
+}
+
+// TransferAffiliateQuota transfers all available affiliate quota into current balance.
+// POST /api/v1/user/aff/transfer
+func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ affiliateSvc, err := h.affiliateServiceOrErr()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ transferred, balance, err := affiliateSvc.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "transferred_quota": transferred,
+ "balance": balance,
+ })
+}
+
type StartIdentityBindingRequest struct {
Provider string `json:"provider" binding:"required"`
RedirectTo string `json:"redirect_to"`
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 6d175488..d4b34fd2 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -80,6 +80,18 @@ func ProvideSettingHandler(settingService *service.SettingService, buildInfo Bui
return NewSettingHandler(settingService, buildInfo.Version)
}
+func ProvideUserHandler(
+ userService *service.UserService,
+ authService *service.AuthService,
+ emailService *service.EmailService,
+ emailCache service.EmailCache,
+ affiliateService *service.AffiliateService,
+) *UserHandler {
+ handler := NewUserHandler(userService, authService, emailService, emailCache)
+ handler.SetAffiliateService(affiliateService)
+ return handler
+}
+
// ProvideHandlers creates the Handlers struct
func ProvideHandlers(
authHandler *AuthHandler,
@@ -125,7 +137,7 @@ func ProvideHandlers(
var ProviderSet = wire.NewSet(
// Top-level handlers
NewAuthHandler,
- NewUserHandler,
+ ProvideUserHandler,
NewAPIKeyHandler,
NewUsageHandler,
NewRedeemHandler,
diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go
new file mode 100644
index 00000000..342ddf4f
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo.go
@@ -0,0 +1,420 @@
+package repository
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+const (
+ affiliateCodeLength = 12
+ affiliateCodeMaxAttempts = 12
+)
+
+var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
+
+type affiliateQueryExecer interface {
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+}
+
+type affiliateRepository struct {
+ client *dbent.Client
+}
+
+func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
+ return &affiliateRepository{client: client}
+}
+
+func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, service.ErrUserNotFound
+ }
+ client := clientFromContext(ctx, r.client)
+ return ensureUserAffiliateWithClient(ctx, client, userID)
+}
+
+func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
+ client := clientFromContext(ctx, r.client)
+ return queryAffiliateByCode(ctx, client, code)
+}
+
+func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
+ var bound bool
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
+ return err
+ }
+
+ res, err := txClient.ExecContext(txCtx,
+ "UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
+ inviterID, userID,
+ )
+ if err != nil {
+ return fmt.Errorf("bind inviter: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ bound = false
+ return nil
+ }
+
+ if _, err = txClient.ExecContext(txCtx,
+ "UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
+ inviterID,
+ ); err != nil {
+ return fmt.Errorf("increment inviter aff_count: %w", err)
+ }
+ bound = true
+ return nil
+ })
+ if err != nil {
+ return false, err
+ }
+ return bound, nil
+}
+
+func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
+ if amount <= 0 {
+ return false, nil
+ }
+
+ var applied bool
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ res, err := txClient.ExecContext(txCtx,
+ "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
+ amount, inviterID,
+ )
+ if err != nil {
+ return err
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ applied = false
+ return nil
+ }
+
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
+VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
+ return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ }
+
+ applied = true
+ return nil
+ })
+ if err != nil {
+ return false, err
+ }
+ return applied, nil
+}
+
+func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
+ var transferred float64
+ var newBalance float64
+
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+
+ rows, err := txClient.QueryContext(txCtx, `
+WITH claimed AS (
+ SELECT aff_quota::double precision AS amount
+ FROM user_affiliates
+ WHERE user_id = $1
+ AND aff_quota > 0
+ FOR UPDATE
+),
+cleared AS (
+ UPDATE user_affiliates ua
+ SET aff_quota = 0,
+ updated_at = NOW()
+ FROM claimed c
+ WHERE ua.user_id = $1
+ RETURNING c.amount
+)
+SELECT amount
+FROM cleared`, userID)
+ if err != nil {
+ return fmt.Errorf("claim affiliate quota: %w", err)
+ }
+
+ if !rows.Next() {
+ _ = rows.Close()
+ if err := rows.Err(); err != nil {
+ return err
+ }
+ return service.ErrAffiliateQuotaEmpty
+ }
+ if err := rows.Scan(&transferred); err != nil {
+ _ = rows.Close()
+ return err
+ }
+ if err := rows.Close(); err != nil {
+ return err
+ }
+ if transferred <= 0 {
+ return service.ErrAffiliateQuotaEmpty
+ }
+
+ affected, err := txClient.User.Update().
+ Where(user.IDEQ(userID)).
+ AddBalance(transferred).
+ AddTotalRecharged(transferred).
+ Save(txCtx)
+ if err != nil {
+ return fmt.Errorf("credit user balance by affiliate quota: %w", err)
+ }
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+
+ newBalance, err = queryUserBalance(txCtx, txClient, userID)
+ if err != nil {
+ return err
+ }
+
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
+VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
+ return fmt.Errorf("insert affiliate transfer ledger: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return 0, 0, err
+ }
+
+ return transferred, newBalance, nil
+}
+
+func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
+ if limit <= 0 {
+ limit = 100
+ }
+ client := clientFromContext(ctx, r.client)
+ rows, err := client.QueryContext(ctx, `
+SELECT ua.user_id,
+ COALESCE(u.email, ''),
+ COALESCE(u.username, ''),
+ ua.created_at
+FROM user_affiliates ua
+LEFT JOIN users u ON u.id = ua.user_id
+WHERE ua.inviter_id = $1
+ORDER BY ua.created_at DESC
+LIMIT $2`, inviterID, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ invitees := make([]service.AffiliateInvitee, 0)
+ for rows.Next() {
+ var item service.AffiliateInvitee
+ var createdAt time.Time
+ if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
+ return nil, err
+ }
+ item.CreatedAt = &createdAt
+ invitees = append(invitees, item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return invitees, nil
+}
+
+func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return fn(ctx, tx.Client())
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return fmt.Errorf("begin affiliate transaction: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx, tx.Client()); err != nil {
+ return err
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("commit affiliate transaction: %w", err)
+ }
+ return nil
+}
+
+func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
+ summary, err := queryAffiliateByUserID(ctx, client, userID)
+ if err == nil {
+ return summary, nil
+ }
+ if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
+ return nil, err
+ }
+
+ for i := 0; i < affiliateCodeMaxAttempts; i++ {
+ code, codeErr := generateAffiliateCode()
+ if codeErr != nil {
+ return nil, codeErr
+ }
+ _, insertErr := client.ExecContext(ctx, `
+INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
+VALUES ($1, $2, NOW(), NOW())
+ON CONFLICT (user_id) DO NOTHING`, userID, code)
+ if insertErr == nil {
+ break
+ }
+ if isAffiliateUniqueViolation(insertErr) {
+ continue
+ }
+ return nil, insertErr
+ }
+
+ return queryAffiliateByUserID(ctx, client, userID)
+}
+
+func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
+ rows, err := client.QueryContext(ctx, `
+SELECT user_id,
+ aff_code,
+ inviter_id,
+ aff_count,
+ aff_quota::double precision,
+ aff_history_quota::double precision,
+ created_at,
+ updated_at
+FROM user_affiliates
+WHERE user_id = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrAffiliateProfileNotFound
+ }
+
+ var out service.AffiliateSummary
+ var inviterID sql.NullInt64
+ if err := rows.Scan(
+ &out.UserID,
+ &out.AffCode,
+ &inviterID,
+ &out.AffCount,
+ &out.AffQuota,
+ &out.AffHistoryQuota,
+ &out.CreatedAt,
+ &out.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ if inviterID.Valid {
+ out.InviterID = &inviterID.Int64
+ }
+ return &out, nil
+}
+
+func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
+ rows, err := client.QueryContext(ctx, `
+SELECT user_id,
+ aff_code,
+ inviter_id,
+ aff_count,
+ aff_quota::double precision,
+ aff_history_quota::double precision,
+ created_at,
+ updated_at
+FROM user_affiliates
+WHERE aff_code = $1
+LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrAffiliateProfileNotFound
+ }
+
+ var out service.AffiliateSummary
+ var inviterID sql.NullInt64
+ if err := rows.Scan(
+ &out.UserID,
+ &out.AffCode,
+ &inviterID,
+ &out.AffCount,
+ &out.AffQuota,
+ &out.AffHistoryQuota,
+ &out.CreatedAt,
+ &out.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ if inviterID.Valid {
+ out.InviterID = &inviterID.Int64
+ }
+ return &out, nil
+}
+
+func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
+ rows, err := client.QueryContext(ctx,
+ "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
+ userID,
+ )
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, service.ErrUserNotFound
+ }
+ var balance float64
+ if err := rows.Scan(&balance); err != nil {
+ return 0, err
+ }
+ return balance, nil
+}
+
+func generateAffiliateCode() (string, error) {
+ buf := make([]byte, affiliateCodeLength)
+ if _, err := rand.Read(buf); err != nil {
+ return "", fmt.Errorf("generate affiliate code: %w", err)
+ }
+ for i := range buf {
+ buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
+ }
+ return string(buf), nil
+}
+
+func isAffiliateUniqueViolation(err error) bool {
+ var pqErr *pq.Error
+ if errors.As(err, &pqErr) {
+ return string(pqErr.Code) == "23505"
+ }
+ return false
+}
diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go
new file mode 100644
index 00000000..3ab5c0fb
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo_integration_test.go
@@ -0,0 +1,114 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value float64
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value int
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 5.5,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.NoError(t, err)
+ require.InDelta(t, 12.34, transferred, 1e-9)
+ require.InDelta(t, 17.84, balance, 1e-9)
+
+ affQuota := querySingleFloat(t, txCtx, client,
+ "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
+ require.InDelta(t, 0.0, affQuota, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 17.84, persistedBalance, 1e-9)
+
+ ledgerCount := querySingleInt(t, txCtx, client,
+ "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
+ require.Equal(t, 1, ledgerCount)
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 3.21,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
+ require.InDelta(t, 0.0, transferred, 1e-9)
+ require.InDelta(t, 0.0, balance, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 3.21, persistedBalance, 1e-9)
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 6d24d312..f07bbb33 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
+ NewAffiliateRepository,
// Cache implementations
NewGatewayCache,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index e89ef3d9..35a6524a 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -715,6 +715,7 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
+ "affiliate_rebate_rate": 20,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -895,6 +896,7 @@ func TestAPIContracts(t *testing.T) {
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
+ "affiliate_rebate_rate": 20,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index babab125..9976954c 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -25,6 +25,8 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ user.GET("/aff", h.User.GetAffiliate)
+ user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go
new file mode 100644
index 00000000..6fa5b423
--- /dev/null
+++ b/backend/internal/service/affiliate_service.go
@@ -0,0 +1,288 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "math"
+ "strconv"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+var (
+ ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
+ ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
+ ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
+ ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
+)
+
+const (
+ affiliateInviteesLimit = 100
+)
+
+type AffiliateSummary struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+type AffiliateInvitee struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ CreatedAt *time.Time `json:"created_at,omitempty"`
+}
+
+type AffiliateDetail struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ Invitees []AffiliateInvitee `json:"invitees"`
+}
+
+type AffiliateRepository interface {
+ EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
+ GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
+ BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
+ AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
+ TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
+ ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
+}
+
+type AffiliateService struct {
+ repo AffiliateRepository
+ settingRepo SettingRepository
+ authCacheInvalidator APIKeyAuthCacheInvalidator
+ billingCacheService *BillingCacheService
+}
+
+func NewAffiliateService(repo AffiliateRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
+ return &AffiliateService{
+ repo: repo,
+ settingRepo: settingRepo,
+ authCacheInvalidator: authCacheInvalidator,
+ billingCacheService: billingCacheService,
+ }
+}
+
+func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
+ }
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.EnsureUserAffiliate(ctx, userID)
+}
+
+func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
+ summary, err := s.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ invitees, err := s.listInvitees(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ return &AffiliateDetail{
+ UserID: summary.UserID,
+ AffCode: summary.AffCode,
+ InviterID: summary.InviterID,
+ AffCount: summary.AffCount,
+ AffQuota: summary.AffQuota,
+ AffHistoryQuota: summary.AffHistoryQuota,
+ Invitees: invitees,
+ }, nil
+}
+
+func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if code == "" {
+ return nil
+ }
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+
+ selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return err
+ }
+ if selfSummary.InviterID != nil {
+ return nil
+ }
+
+ inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
+ if err != nil {
+ if errors.Is(err, ErrAffiliateProfileNotFound) {
+ return ErrAffiliateCodeInvalid
+ }
+ return err
+ }
+ if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
+ return ErrAffiliateCodeInvalid
+ }
+
+ bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
+ if err != nil {
+ return err
+ }
+ if !bound {
+ return ErrAffiliateAlreadyBound
+ }
+ return nil
+}
+
+func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, nil
+ }
+ if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
+ return 0, nil
+ }
+
+ inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
+ if err != nil {
+ return 0, err
+ }
+ if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
+ return 0, nil
+ }
+
+ rebateRatePercent := s.loadAffiliateRebateRatePercent(ctx)
+ rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
+ if rebate <= 0 {
+ return 0, nil
+ }
+
+ if _, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID); err != nil {
+ return 0, err
+ }
+
+ applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
+ if err != nil {
+ return 0, err
+ }
+ if !applied {
+ return 0, nil
+ }
+ return rebate, nil
+}
+
+func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+
+ transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
+ if err != nil {
+ return 0, 0, err
+ }
+ if transferred > 0 {
+ s.invalidateAffiliateCaches(ctx, userID)
+ }
+ return transferred, balance, nil
+}
+
+func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
+ if err != nil {
+ return nil, err
+ }
+ for i := range invitees {
+ invitees[i].Email = maskEmail(invitees[i].Email)
+ }
+ return invitees, nil
+}
+
+func (s *AffiliateService) loadAffiliateRebateRatePercent(ctx context.Context) float64 {
+ if s == nil || s.settingRepo == nil {
+ return AffiliateRebateRateDefault
+ }
+
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
+ if err != nil {
+ return AffiliateRebateRateDefault
+ }
+
+ rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
+ if err != nil {
+ return AffiliateRebateRateDefault
+ }
+ if math.IsNaN(rate) || math.IsInf(rate, 0) {
+ return AffiliateRebateRateDefault
+ }
+ if rate < AffiliateRebateRateMin {
+ return AffiliateRebateRateMin
+ }
+ if rate > AffiliateRebateRateMax {
+ return AffiliateRebateRateMax
+ }
+ return rate
+}
+
+func roundTo(v float64, scale int) float64 {
+ factor := math.Pow10(scale)
+ return math.Round(v*factor) / factor
+}
+
+func maskEmail(email string) string {
+ email = strings.TrimSpace(email)
+ if email == "" {
+ return ""
+ }
+ at := strings.Index(email, "@")
+ if at <= 0 || at >= len(email)-1 {
+ return "***"
+ }
+
+ local := email[:at]
+ domain := email[at+1:]
+ dot := strings.LastIndex(domain, ".")
+
+ maskedLocal := maskSegment(local)
+ if dot <= 0 || dot >= len(domain)-1 {
+ return maskedLocal + "@" + maskSegment(domain)
+ }
+
+ domainName := domain[:dot]
+ tld := domain[dot:]
+ return maskedLocal + "@" + maskSegment(domainName) + tld
+}
+
+func maskSegment(s string) string {
+ r := []rune(s)
+ if len(r) == 0 {
+ return "***"
+ }
+ if len(r) == 1 {
+ return string(r[0]) + "***"
+ }
+ return string(r[0]) + "***"
+}
+
+func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ if s.billingCacheService != nil {
+ go func() {
+ cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
+ }()
+ }
+}
diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go
new file mode 100644
index 00000000..6adf879d
--- /dev/null
+++ b/backend/internal/service/affiliate_service_test.go
@@ -0,0 +1,59 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type affiliateSettingRepoStub struct {
+ value string
+ err error
+}
+
+func (s *affiliateSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, s.err }
+func (s *affiliateSettingRepoStub) GetValue(context.Context, string) (string, error) {
+ if s.err != nil {
+ return "", s.err
+ }
+ return s.value, nil
+}
+func (s *affiliateSettingRepoStub) Set(context.Context, string, string) error { return s.err }
+func (s *affiliateSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return map[string]string{}, nil
+}
+func (s *affiliateSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return s.err
+}
+func (s *affiliateSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return map[string]string{}, nil
+}
+func (s *affiliateSettingRepoStub) Delete(context.Context, string) error { return s.err }
+
+func TestAffiliateRebateRatePercentSemantics(t *testing.T) {
+ t.Parallel()
+
+ svc := &AffiliateService{settingRepo: &affiliateSettingRepoStub{value: "1"}}
+ rate := svc.loadAffiliateRebateRatePercent(context.Background())
+ require.Equal(t, 1.0, rate)
+
+ svc.settingRepo = &affiliateSettingRepoStub{value: "0.2"}
+ rate = svc.loadAffiliateRebateRatePercent(context.Background())
+ require.Equal(t, 0.2, rate)
+}
+
+func TestMaskEmail(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
+ require.Equal(t, "x***@d***", maskEmail("x@domain"))
+ require.Equal(t, "", maskEmail(""))
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index e45d8d66..fe0c32f5 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -72,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
+ affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -121,13 +122,26 @@ func (s *AuthService) EntClient() *dbent.Client {
return s.entClient
}
+func (s *AuthService) SetAffiliateService(affiliateService *AffiliateService) {
+ if s == nil {
+ return
+ }
+ s.affiliateService = affiliateService
+}
+
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
}
-// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
-func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
+// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
+// affiliateCode 使用可选参数以兼容旧调用方。
+func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string, affiliateCode ...string) (string, *User, error) {
+ affiliateCodeRaw := ""
+ if len(affiliateCode) > 0 {
+ affiliateCodeRaw = affiliateCode[0]
+ }
+
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -223,6 +237,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ if s.affiliateService != nil {
+ if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
+ }
+ if code := strings.TrimSpace(affiliateCodeRaw); code != "" {
+ if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
+ // 邀请返利码绑定失败不影响注册,只记录日志
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
+ }
+ }
+ }
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index cf47b76f..23afeb87 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -18,6 +18,13 @@ const (
RoleUser = domain.RoleUser
)
+// Affiliate rebate settings
+const (
+ AffiliateRebateRateDefault = 20.0
+ AffiliateRebateRateMin = 0.0
+ AffiliateRebateRateMax = 100.0
+)
+
// Platform constants
const (
PlatformAnthropic = domain.PlatformAnthropic
@@ -87,6 +94,7 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
+ SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 243edff3..c6167447 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"log/slog"
@@ -268,6 +269,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
+ s.applyAffiliateRebateForOrder(ctx, o)
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
@@ -281,6 +283,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
+ s.applyAffiliateRebateForOrder(ctx, o)
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
@@ -358,6 +361,139 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
+func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) {
+ if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
+ return
+ }
+ if s.affiliateService == nil {
+ return
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
+ })
+ return
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return
+ }
+ if !claimed {
+ return
+ }
+
+ rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return
+ }
+
+ if rebateAmount <= 0 {
+ if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{
+ "baseAmount": o.Amount,
+ "reason": "no inviter bound or rebate amount <= 0",
+ }); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return
+ }
+ if err := tx.Commit(); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
+ })
+ }
+ return
+ }
+
+ if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
+ "baseAmount": o.Amount,
+ "rebateAmount": rebateAmount,
+ }); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return
+ }
+
+ if err := tx.Commit(); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
+ })
+ }
+}
+
+func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
+ if client == nil {
+ return false, errors.New("nil payment client")
+ }
+ oid := strconv.FormatInt(orderID, 10)
+ detail, _ := json.Marshal(map[string]any{
+ "baseAmount": baseAmount,
+ "status": "reserved",
+ })
+ rows, err := client.QueryContext(ctx, `
+INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
+SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW()
+WHERE NOT EXISTS (
+ SELECT 1
+ FROM payment_audit_logs
+ WHERE order_id = $1
+ AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
+)
+ON CONFLICT (order_id, action) DO NOTHING
+RETURNING id`, oid, string(detail))
+ if err != nil {
+ return false, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return false, err
+ }
+ return false, nil
+ }
+ var claimID int64
+ if err := rows.Scan(&claimID); err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
+ if client == nil {
+ return errors.New("nil payment client")
+ }
+ oid := strconv.FormatInt(orderID, 10)
+ detailJSON, _ := json.Marshal(detail)
+ updated, err := client.PaymentAuditLog.Update().
+ Where(
+ paymentauditlog.OrderIDEQ(oid),
+ paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
+ ).
+ SetAction(action).
+ SetDetail(string(detailJSON)).
+ SetOperator("system").
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if updated == 0 {
+ return errors.New("affiliate rebate claim log not found")
+ }
+ return nil
+}
+
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index 97fd76a0..15f6feeb 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -170,17 +170,18 @@ type TopUserStat struct {
// --- Service ---
type PaymentService struct {
- providerMu sync.Mutex
- providersLoaded bool
- entClient *dbent.Client
- registry *payment.Registry
- loadBalancer payment.LoadBalancer
- redeemService *RedeemService
- subscriptionSvc *SubscriptionService
- configService *PaymentConfigService
- userRepo UserRepository
- groupRepo GroupRepository
- resumeService *PaymentResumeService
+ providerMu sync.Mutex
+ providersLoaded bool
+ entClient *dbent.Client
+ registry *payment.Registry
+ loadBalancer payment.LoadBalancer
+ redeemService *RedeemService
+ subscriptionSvc *SubscriptionService
+ configService *PaymentConfigService
+ userRepo UserRepository
+ groupRepo GroupRepository
+ resumeService *PaymentResumeService
+ affiliateService *AffiliateService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
@@ -189,6 +190,13 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load
return svc
}
+func (s *PaymentService) SetAffiliateService(affiliateService *AffiliateService) {
+ if s == nil {
+ return
+ }
+ s.affiliateService = affiliateService
+}
+
// --- Provider Registry ---
// EnsureProviders lazily initializes the provider registry on first call.
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index c79d8949..f3801c48 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"log/slog"
+ "math"
"net/url"
"sort"
"strconv"
@@ -1167,6 +1168,8 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
+ settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
+ updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
@@ -1719,6 +1722,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectUserInfoUsernamePath: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
@@ -1846,6 +1850,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
+ if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
+ result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
+ } else {
+ result.AffiliateRebateRate = AffiliateRebateRateDefault
+ }
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -2130,6 +2139,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result
}
+func clampAffiliateRebateRate(value float64) float64 {
+ if math.IsNaN(value) || math.IsInf(value, 0) {
+ return AffiliateRebateRateDefault
+ }
+ if value < AffiliateRebateRateMin {
+ return AffiliateRebateRateMin
+ }
+ if value > AffiliateRebateRateMax {
+ return AffiliateRebateRateMax
+ }
+ return value
+}
+
func isFalseSettingValue(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index ddd4fff6..8a3bd421 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
+ AffiliateRebateRate float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 86bfc327..d8a6a332 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -391,6 +391,53 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
return svc
}
+func ProvideAuthService(
+ entClient *dbent.Client,
+ userRepo UserRepository,
+ redeemRepo RedeemCodeRepository,
+ refreshTokenCache RefreshTokenCache,
+ cfg *config.Config,
+ settingService *SettingService,
+ emailService *EmailService,
+ turnstileService *TurnstileService,
+ emailQueueService *EmailQueueService,
+ promoService *PromoService,
+ defaultSubAssigner DefaultSubscriptionAssigner,
+ affiliateService *AffiliateService,
+) *AuthService {
+ svc := NewAuthService(
+ entClient,
+ userRepo,
+ redeemRepo,
+ refreshTokenCache,
+ cfg,
+ settingService,
+ emailService,
+ turnstileService,
+ emailQueueService,
+ promoService,
+ defaultSubAssigner,
+ )
+ svc.SetAffiliateService(affiliateService)
+ return svc
+}
+
+func ProvidePaymentService(
+ entClient *dbent.Client,
+ registry *payment.Registry,
+ loadBalancer payment.LoadBalancer,
+ redeemService *RedeemService,
+ subscriptionSvc *SubscriptionService,
+ configService *PaymentConfigService,
+ userRepo UserRepository,
+ groupRepo GroupRepository,
+ affiliateService *AffiliateService,
+) *PaymentService {
+ svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo)
+ svc.SetAffiliateService(affiliateService)
+ return svc
+}
+
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
func ProvideBillingCacheService(
cache BillingCache,
@@ -407,7 +454,7 @@ func ProvideBillingCacheService(
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
- NewAuthService,
+ ProvideAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
@@ -486,8 +533,9 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
+ NewAffiliateService,
ProvidePaymentConfigService,
- NewPaymentService,
+ ProvidePaymentService,
ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
ProvideChannelMonitorService,
diff --git a/backend/migrations/130_add_user_affiliates.sql b/backend/migrations/130_add_user_affiliates.sql
new file mode 100644
index 00000000..d8c001e0
--- /dev/null
+++ b/backend/migrations/130_add_user_affiliates.sql
@@ -0,0 +1,20 @@
+CREATE TABLE IF NOT EXISTS user_affiliates (
+ user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
+ aff_code VARCHAR(32) NOT NULL UNIQUE,
+ inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ aff_count INTEGER NOT NULL DEFAULT 0,
+ aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
+
+COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
+COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
+COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
+COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
+COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
+COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';
diff --git a/backend/migrations/131_affiliate_rebate_hardening.sql b/backend/migrations/131_affiliate_rebate_hardening.sql
new file mode 100644
index 00000000..81e37a9e
--- /dev/null
+++ b/backend/migrations/131_affiliate_rebate_hardening.sql
@@ -0,0 +1,58 @@
+-- 1) Normalize historical affiliate rebate rate values.
+-- Legacy compatibility treated 0 20%).
+-- We now use pure percentage semantics, so convert persisted fractional values once.
+UPDATE settings
+SET value = to_char((value::numeric * 100), 'FM999999990.########'),
+ updated_at = NOW()
+WHERE key = 'affiliate_rebate_rate'
+ AND value ~ '^-?[0-9]+(\\.[0-9]+)?$'
+ AND value::numeric > 0
+ AND value::numeric <= 1;
+
+-- 2) Affiliate ledger for accrual/transfer traceability.
+CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ action VARCHAR(32) NOT NULL,
+ amount DECIMAL(20,8) NOT NULL,
+ source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
+
+COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
+COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
+
+-- 3) Enforce idempotency at DB layer for payment audit actions.
+WITH ranked AS (
+ SELECT id,
+ ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
+ FROM payment_audit_logs
+)
+DELETE FROM payment_audit_logs p
+USING ranked r
+WHERE p.id = r.id
+ AND r.rn > 1;
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
+ON payment_audit_logs(order_id, action);
+
+-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
+INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
+SELECT po.id::text,
+ 'AFFILIATE_REBATE_SKIPPED',
+ '{"reason":"baseline before affiliate rebate idempotency rollout"}',
+ 'system',
+ NOW()
+FROM payment_orders po
+WHERE po.order_type = 'balance'
+ AND po.status = 'COMPLETED'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM payment_audit_logs pal
+ WHERE pal.order_id = po.id::text
+ AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
+ );
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index b9f24663..971c2314 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -308,6 +308,7 @@ export interface SystemSettings {
totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
default_balance: number;
+ affiliate_rebate_rate: number;
default_concurrency: number;
default_user_rpm_limit: number;
default_subscriptions: DefaultSubscriptionSetting[];
@@ -489,6 +490,7 @@ export interface UpdateSettingsRequest {
invitation_code_enabled?: boolean;
totp_enabled?: boolean; // TOTP 双因素认证
default_balance?: number;
+ affiliate_rebate_rate?: number;
default_concurrency?: number;
default_user_rpm_limit?: number;
default_subscriptions?: DefaultSubscriptionSetting[];
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index fd3cedb9..da7a91eb 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -9,7 +9,14 @@ import {
prepareOAuthBindAccessTokenCookie,
type WeChatOAuthPublicSettings,
} from './auth'
-import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types'
+import type {
+ User,
+ ChangePasswordRequest,
+ NotifyEmailEntry,
+ UserAuthProvider,
+ UserAffiliateDetail,
+ AffiliateTransferResponse
+} from '@/types'
/**
* Get current user profile
@@ -168,6 +175,16 @@ export async function startOAuthBinding(
window.location.href = startURL
}
+export async function getAffiliateDetail(): Promise {
+ const { data } = await apiClient.get('/user/aff')
+ return data
+}
+
+export async function transferAffiliateQuota(): Promise {
+ const { data } = await apiClient.post('/user/aff/transfer')
+ return data
+}
+
export const userAPI = {
getProfile,
updateProfile,
@@ -180,7 +197,9 @@ export const userAPI = {
bindEmailIdentity,
unbindAuthIdentity,
buildOAuthBindingStartURL,
- startOAuthBinding
+ startOAuthBinding,
+ getAffiliateDetail,
+ transferAffiliateQuota
}
export default userAPI
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 910f24cd..a3a8c30e 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -656,6 +656,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] {
{ path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
+ { path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
path: `/custom/${item.id}`,
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index eeb6087b..e7514f0e 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -346,6 +346,7 @@ export default {
apiKeys: 'API Keys',
usage: 'Usage',
redeem: 'Redeem',
+ affiliate: 'Affiliate Rebates',
profile: 'Profile',
users: 'Users',
groups: 'Groups',
@@ -972,6 +973,47 @@ export default {
}
},
+ affiliate: {
+ title: 'Affiliate Rebates',
+ description: 'Invite new users and convert your rebate quota into account balance',
+ yourCode: 'Your Affiliate Code',
+ inviteLink: 'Invite Link',
+ copyCode: 'Copy Code',
+ copyLink: 'Copy Link',
+ codeCopied: 'Affiliate code copied',
+ linkCopied: 'Invite link copied',
+ loadFailed: 'Failed to load affiliate data',
+ transferFailed: 'Failed to transfer affiliate quota',
+ stats: {
+ invitedUsers: 'Invited Users',
+ availableQuota: 'Available Rebate Quota',
+ totalQuota: 'Historical Rebate Quota'
+ },
+ transfer: {
+ title: 'Transfer Rebate Quota',
+ description: 'Move available rebate quota into your account balance',
+ button: 'Transfer to Balance',
+ transferring: 'Transferring...',
+ empty: 'No available rebate quota',
+ success: '{amount} has been transferred to your balance'
+ },
+ invitees: {
+ title: 'Invited Users',
+ empty: 'No invited users yet',
+ columns: {
+ email: 'Email',
+ username: 'Username',
+ joinedAt: 'Joined At'
+ }
+ },
+ tips: {
+ title: 'How It Works',
+ line1: 'Share your affiliate code or invite link with new users.',
+ line2: 'When invitees recharge, you receive rebate quota based on the configured rate.',
+ line3: 'Transfer rebate quota to balance at any time.'
+ }
+ },
+
// Redeem
redeem: {
title: 'Redeem Code',
@@ -4837,6 +4879,9 @@ export default {
description: 'Default values for new users',
defaultBalance: 'Default Balance',
defaultBalanceHint: 'Initial balance for new users',
+ affiliateRebateRate: 'Affiliate Rebate Rate',
+ affiliateRebateRateHint:
+ 'Rebate percentage credited to inviter after recharge (0-100%, e.g. 10 means 10%)',
defaultConcurrency: 'Default Concurrency',
defaultConcurrencyHint: 'Maximum concurrent requests for new users',
defaultUserRpmLimit: 'Default User RPM Limit',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 5f91a2f6..3057f93e 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -346,6 +346,7 @@ export default {
apiKeys: 'API 密钥',
usage: '使用记录',
redeem: '兑换',
+ affiliate: '邀请返利',
profile: '个人资料',
users: '用户管理',
groups: '分组管理',
@@ -976,6 +977,47 @@ export default {
}
},
+ affiliate: {
+ title: '邀请返利',
+ description: '邀请新用户注册,并将返利额度转入账户余额',
+ yourCode: '我的邀请码',
+ inviteLink: '邀请链接',
+ copyCode: '复制邀请码',
+ copyLink: '复制链接',
+ codeCopied: '邀请码已复制',
+ linkCopied: '邀请链接已复制',
+ loadFailed: '加载邀请返利数据失败',
+ transferFailed: '转入余额失败',
+ stats: {
+ invitedUsers: '邀请人数',
+ availableQuota: '可转返利额度',
+ totalQuota: '历史返利额度'
+ },
+ transfer: {
+ title: '返利额度转余额',
+ description: '将当前可用返利额度一键转入账户余额',
+ button: '转入余额',
+ transferring: '转入中...',
+ empty: '当前没有可转入额度',
+ success: '已转入余额:{amount}'
+ },
+ invitees: {
+ title: '已邀请用户',
+ empty: '暂无邀请记录',
+ columns: {
+ email: '邮箱',
+ username: '用户名',
+ joinedAt: '注册时间'
+ }
+ },
+ tips: {
+ title: '使用说明',
+ line1: '将邀请码或邀请链接分享给新用户。',
+ line2: '被邀请用户充值后,你可获得对应比例的返利额度。',
+ line3: '返利额度可随时转入账户余额。'
+ }
+ },
+
// Redeem
redeem: {
title: '兑换码',
@@ -5000,6 +5042,8 @@ export default {
description: '新用户的默认值',
defaultBalance: '默认余额',
defaultBalanceHint: '新用户的初始余额',
+ affiliateRebateRate: '邀请返利比例',
+ affiliateRebateRateHint: '充值后返给邀请人的比例(0-100%,例如填写 10 表示返利 10%)',
defaultConcurrency: '默认并发数',
defaultConcurrencyHint: '新用户的最大并发请求数',
defaultUserRpmLimit: '默认用户 RPM 限制',
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index dc886b23..2b85b97e 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -197,6 +197,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'redeem.description'
}
},
+ {
+ path: '/affiliate',
+ name: 'Affiliate',
+ component: () => import('@/views/user/AffiliateView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: false,
+ title: 'Affiliate',
+ titleKey: 'affiliate.title',
+ descriptionKey: 'affiliate.description'
+ }
+ },
{
path: '/available-channels',
name: 'UserAvailableChannels',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 2329bb25..e2f41900 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -122,6 +122,29 @@ export interface RegisterRequest {
turnstile_token?: string
promo_code?: string
invitation_code?: string
+ aff_code?: string
+}
+
+export interface AffiliateInvitee {
+ user_id: number
+ email: string
+ username: string
+ created_at?: string
+}
+
+export interface UserAffiliateDetail {
+ user_id: number
+ aff_code: string
+ inviter_id?: number | null
+ aff_count: number
+ aff_quota: number
+ aff_history_quota: number
+ invitees: AffiliateInvitee[]
+}
+
+export interface AffiliateTransferResponse {
+ transferred_quota: number
+ balance: number
}
export interface SendVerifyCodeRequest {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 3e167938..6da4b21a 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2153,6 +2153,31 @@
{{ t("admin.settings.defaults.defaultBalanceHint") }}
+
+
+
+
+ %
+
+
+ {{ t("admin.settings.defaults.affiliateRebateRateHint") }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.openai.compactModeDesc') }}
+
+
+
+
+
+
+
+
+
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
+
+
+
+
+
@@ -2918,7 +2957,8 @@ import type {
AccountPlatform,
AccountType,
CheckMixedChannelResponse,
- CreateAccountRequest
+ CreateAccountRequest,
+ OpenAICompactMode
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
@@ -3059,6 +3099,7 @@ const editWeeklyResetDay = ref(null)
const editWeeklyResetHour = ref(null)
const editResetTimezone = ref(null)
const modelMappings = ref([])
+const openAICompactModelMappings = ref([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -3071,6 +3112,7 @@ const customErrorCodeInput = ref(null)
const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false)
+const openAICompactMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -3112,10 +3154,16 @@ const bedrockApiKeyValue = ref('')
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping')
+const getOpenAICompactModelMappingKey = createStableObjectKeyResolver('create-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver('create-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver('create-temp-unsched-rule')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
+const openAICompactModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
+ { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
+ { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
+])
function buildAntigravityExtra(): Record | undefined {
const extra: Record = {}
@@ -3124,6 +3172,9 @@ function buildAntigravityExtra(): Record | undefined {
return Object.keys(extra).length > 0 ? extra : undefined
}
+const buildOpenAICompactModelMapping = () =>
+ buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
@@ -3489,6 +3540,14 @@ const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
}
+const addOpenAICompactModelMapping = () => {
+ openAICompactModelMappings.value.push({ from: '', to: '' })
+}
+
+const removeOpenAICompactModelMapping = (index: number) => {
+ openAICompactModelMappings.value.splice(index, 1)
+}
+
const removeModelMapping = (index: number) => {
modelMappings.value.splice(index, 1)
}
@@ -3781,6 +3840,7 @@ const resetForm = () => {
editWeeklyResetHour.value = null
editResetTimezone.value = null
modelMappings.value = []
+ openAICompactModelMappings.value = []
modelRestrictionMode.value = 'whitelist'
allowedModels.value = [...claudeModels] // Default fill related models
@@ -3797,6 +3857,7 @@ const resetForm = () => {
interceptWarmupRequests.value = false
autoPauseOnExpired.value = true
openaiPassthroughEnabled.value = false
+ openAICompactMode.value = 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -3874,6 +3935,11 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined
}
@@ -4086,6 +4152,12 @@ const handleSubmit = async () => {
credentials.model_mapping = modelMapping
}
}
+ if (form.platform === 'openai') {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -4198,6 +4270,14 @@ const createAccountAndFinish = async (
finalExtra = quotaExtra
}
}
+ if (platform === 'openai') {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete credentials.compact_model_mapping
+ }
+ }
await doCreateAccount({
name: form.name,
notes: form.notes,
@@ -4252,6 +4332,12 @@ const handleOpenAIExchange = async (authCode: string) => {
credentials.model_mapping = modelMapping
}
}
+ if (shouldCreateOpenAI) {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// 应用临时不可调度配置
if (!applyTempUnschedConfig(credentials)) {
@@ -4344,6 +4430,12 @@ const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string)
credentials.model_mapping = modelMapping
}
}
+ if (shouldCreateOpenAI) {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 59ca0b9c..42211ba7 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1306,6 +1306,64 @@
+
+
+
+
+
+ {{ t('admin.accounts.openai.compactModeDesc') }}
+
+
+
+
+
+
+
+ {{ t(openAICompactStatusKey) }}
+
+ {{ t('admin.accounts.openai.compactLastChecked') }}:
+ {{ formatDateTime(new Date(String(account.extra.openai_compact_checked_at))) }}
+
+
+
+
+
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
+
+
+
+
+
@@ -1849,7 +1907,7 @@ import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
-import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
+import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1859,7 +1917,7 @@ import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
-import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
+import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import {
OPENAI_WS_MODE_CTX_POOL,
@@ -1934,6 +1992,7 @@ const isBedrockAPIKeyMode = computed(() =>
(props.account?.credentials as Record
)?.auth_mode === 'apikey'
)
const modelMappings = ref([])
+const openAICompactModelMappings = ref([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -1953,6 +2012,7 @@ const antigravityModelMappings = ref([])
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('edit-model-mapping')
+const getOpenAICompactModelMappingKey = createStableObjectKeyResolver('edit-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver('edit-temp-unsched-rule')
@@ -1992,6 +2052,7 @@ const customBaseUrl = ref('')
// OpenAI 自动透传开关(OAuth/API Key)
const openaiPassthroughEnabled = ref(false)
+const openAICompactMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -2045,9 +2106,27 @@ const openaiResponsesWebSocketV2Mode = computed({
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
)
+const openAICompactModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
+ { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
+ { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
+])
const isOpenAIModelRestrictionDisabled = computed(() =>
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
)
+const openAICompactStatusKey = computed(() => {
+ const extra = props.account?.extra as Record | undefined
+ if (!props.account || props.account.platform !== 'openai') return ''
+ const mode = typeof extra?.openai_compact_mode === 'string' ? extra.openai_compact_mode : 'auto'
+ if (mode === 'force_on') return 'admin.accounts.openai.compactSupported'
+ if (mode === 'force_off') return 'admin.accounts.openai.compactUnsupported'
+ if (typeof extra?.openai_compact_supported === 'boolean') {
+ return extra.openai_compact_supported
+ ? 'admin.accounts.openai.compactSupported'
+ : 'admin.accounts.openai.compactUnsupported'
+ }
+ return 'admin.accounts.openai.compactUnknown'
+})
// Computed: current preset mappings based on platform
const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic'))
@@ -2177,6 +2256,8 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
openaiPassthroughEnabled.value = false
+ openAICompactMode.value = 'auto'
+ openAICompactModelMappings.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -2184,6 +2265,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
webSearchEmulationMode.value = 'default'
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
+ openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
modeKey: 'openai_oauth_responses_websockets_v2_mode',
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
@@ -2199,6 +2281,11 @@ const syncFormFromAccount = (newAccount: Account | null) => {
if (newAccount.type === 'oauth') {
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
}
+ const credentials = newAccount.credentials as Record | undefined
+ const compactMappings = credentials?.compact_model_mapping as Record | undefined
+ if (compactMappings && typeof compactMappings === 'object') {
+ openAICompactModelMappings.value = Object.entries(compactMappings).map(([from, to]) => ({ from, to }))
+ }
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
@@ -2423,6 +2510,15 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editApiKey.value = ''
}
+async function loadTLSProfiles() {
+ try {
+ const profiles = await adminAPI.tlsFingerprintProfiles.list()
+ tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
+ } catch {
+ tlsFingerprintProfiles.value = []
+ }
+}
+
watch(
[() => props.show, () => props.account],
([show, newAccount], [wasShow, previousAccount]) => {
@@ -2437,15 +2533,6 @@ watch(
{ immediate: true }
)
-const loadTLSProfiles = async () => {
- try {
- const profiles = await adminAPI.tlsFingerprintProfiles.list()
- tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
- } catch {
- tlsFingerprintProfiles.value = []
- }
-}
-
// Model mapping helpers
const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
@@ -2468,6 +2555,14 @@ const addAntigravityModelMapping = () => {
antigravityModelMappings.value.push({ from: '', to: '' })
}
+const addOpenAICompactModelMapping = () => {
+ openAICompactModelMappings.value.push({ from: '', to: '' })
+}
+
+const removeOpenAICompactModelMapping = (index: number) => {
+ openAICompactModelMappings.value.splice(index, 1)
+}
+
const removeAntigravityModelMapping = (index: number) => {
antigravityModelMappings.value.splice(index, 1)
}
@@ -2911,6 +3006,14 @@ const handleSubmit = async () => {
} else if (currentCredentials.model_mapping) {
newCredentials.model_mapping = currentCredentials.model_mapping
}
+ if (props.account.platform === 'openai') {
+ const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+ if (compactModelMapping) {
+ newCredentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete newCredentials.compact_model_mapping
+ }
+ }
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -3036,6 +3139,12 @@ const handleSubmit = async () => {
// 透传模式保留现有映射
newCredentials.model_mapping = currentCredentials.model_mapping
}
+ const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+ if (compactModelMapping) {
+ newCredentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete newCredentials.compact_model_mapping
+ }
updatePayload.credentials = newCredentials
}
@@ -3208,6 +3317,11 @@ const handleSubmit = async () => {
delete newExtra.openai_passthrough
delete newExtra.openai_oauth_passthrough
}
+ if (openAICompactMode.value === 'auto') {
+ delete newExtra.openai_compact_mode
+ } else {
+ newExtra.openai_compact_mode = openAICompactMode.value
+ }
if (props.account.type === 'oauth') {
if (codexCLIOnlyEnabled.value) {
diff --git a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
index 7cdf7999..f758e6b0 100644
--- a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
+++ b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
@@ -122,7 +122,7 @@ describe('AccountStatusIndicator', () => {
}
})
- expect(wrapper.text()).toContain('account.creditsExhausted')
+ expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
it('模型限流 + overages 启用 + AICredits key 生效 → 普通限流样式(积分耗尽,无 ⚡)', () => {
@@ -157,6 +157,6 @@ describe('AccountStatusIndicator', () => {
expect(wrapper.text()).toContain('CSon45')
expect(wrapper.text()).not.toContain('⚡')
// AICredits 积分耗尽状态应显示
- expect(wrapper.text()).toContain('account.creditsExhausted')
+ expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
})
diff --git a/frontend/src/components/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts
new file mode 100644
index 00000000..c82a3840
--- /dev/null
+++ b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts
@@ -0,0 +1,150 @@
+import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+import { defineComponent } from 'vue'
+import AccountTestModal from '../AccountTestModal.vue'
+
+const { getAvailableModelsMock } = vi.hoisted(() => ({
+ getAvailableModelsMock: vi.fn()
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ accounts: {
+ getAvailableModels: getAvailableModelsMock
+ }
+ }
+}))
+
+vi.mock('@/composables/useClipboard', () => ({
+ useClipboard: () => ({
+ copyToClipboard: vi.fn()
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+const BaseDialogStub = defineComponent({
+ name: 'BaseDialog',
+ props: { show: { type: Boolean, default: false } },
+ template: '
'
+})
+
+const SelectStub = defineComponent({
+ name: 'SelectStub',
+ props: {
+ modelValue: { type: [String, Number, Boolean, null], default: '' },
+ options: { type: Array, default: () => [] },
+ valueKey: { type: String, default: 'value' },
+ labelKey: { type: String, default: 'label' }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+ `
+})
+
+const TextAreaStub = defineComponent({
+ name: 'TextArea',
+ props: {
+ modelValue: { type: String, default: '' }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+ `
+})
+
+function buildAccount() {
+ return {
+ id: 1,
+ name: 'OpenAI OAuth',
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ credentials: {},
+ extra: {},
+ concurrency: 1,
+ priority: 1,
+ proxy_id: null,
+ auto_pause_on_expired: false
+ } as any
+}
+
+describe('AccountTestModal', () => {
+ const originalFetch = global.fetch
+
+ beforeEach(() => {
+ getAvailableModelsMock.mockReset()
+ getAvailableModelsMock.mockResolvedValue([
+ { id: 'gpt-5.4', display_name: 'GPT-5.4' }
+ ])
+ global.fetch = vi.fn().mockResolvedValue({
+ ok: true,
+ body: {
+ getReader: () => ({
+ read: vi.fn().mockResolvedValue({ done: true, value: undefined })
+ })
+ }
+ } as any)
+ localStorage.setItem('auth_token', 'test-token')
+ })
+
+ afterEach(() => {
+ global.fetch = originalFetch
+ localStorage.clear()
+ })
+
+ it('posts compact mode for OpenAI compact probe', async () => {
+ const wrapper = mount(AccountTestModal, {
+ props: {
+ show: true,
+ account: buildAccount()
+ },
+ global: {
+ stubs: {
+ BaseDialog: BaseDialogStub,
+ Select: SelectStub,
+ TextArea: TextAreaStub,
+ Icon: true
+ }
+ }
+ })
+
+ await flushPromises()
+ ;(wrapper.vm as any).selectedModelId = 'gpt-5.4'
+ ;(wrapper.vm as any).testMode = 'compact'
+ await (wrapper.vm as any).startTest()
+ await flushPromises()
+
+ expect(global.fetch).toHaveBeenCalledTimes(1)
+ const [, options] = (global.fetch as any).mock.calls[0]
+ expect(JSON.parse(options.body)).toMatchObject({
+ model_id: 'gpt-5.4',
+ mode: 'compact'
+ })
+ })
+})
diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
index e3260168..c4e2a9bc 100644
--- a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
+++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
@@ -26,6 +26,13 @@ vi.mock('@/api/admin', () => ({
accounts: {
update: updateAccountMock,
checkMixedChannelRisk: checkMixedChannelRiskMock
+ },
+ settings: {
+ getWebSearchEmulationConfig: vi.fn().mockResolvedValue({ enabled: false, providers: [] }),
+ getSettings: vi.fn().mockResolvedValue({})
+ },
+ tlsFingerprintProfiles: {
+ list: vi.fn().mockResolvedValue([])
}
}
}))
@@ -82,6 +89,32 @@ const ModelWhitelistSelectorStub = defineComponent({
`
})
+const SelectStub = defineComponent({
+ name: 'SelectStub',
+ props: {
+ modelValue: {
+ type: [String, Number, Boolean, null],
+ default: ''
+ },
+ options: {
+ type: Array,
+ default: () => []
+ }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+ `
+})
+
function buildAccount() {
return {
id: 1,
@@ -119,7 +152,7 @@ function mountModal(account = buildAccount()) {
global: {
stubs: {
BaseDialog: BaseDialogStub,
- Select: true,
+ Select: SelectStub,
Icon: true,
ProxySelector: true,
GroupSelector: true,
@@ -156,4 +189,31 @@ describe('EditAccountModal', () => {
'gpt-5.2': 'gpt-5.2'
})
})
+
+ it('submits OpenAI compact mode and compact-only model mapping', async () => {
+ const account = buildAccount()
+ account.extra = {
+ openai_compact_mode: 'force_on'
+ }
+ account.credentials = {
+ ...account.credentials,
+ compact_model_mapping: {
+ 'gpt-5.4': 'gpt-5.4-openai-compact'
+ }
+ }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_compact_mode).toBe('force_on')
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.compact_model_mapping).toEqual({
+ 'gpt-5.4': 'gpt-5.4-openai-compact'
+ })
+ })
})
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index e7514f0e..5aa63e6a 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2848,6 +2848,22 @@ export default {
codexCLIOnly: 'Codex official clients only',
codexCLIOnlyDesc:
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
+ compactMode: 'Compact mode',
+ compactModeDesc:
+ 'Controls how this account participates in /responses/compact routing. Auto follows probe results, Force On always allows, Force Off always excludes.',
+ compactModeAuto: 'Auto',
+ compactModeForceOn: 'Force On',
+ compactModeForceOff: 'Force Off',
+ compactModelMapping: 'Compact-only model mapping',
+ compactModelMappingDesc:
+ 'Only applies to /responses/compact. Use this when the upstream compact endpoint requires a special compact model.',
+ compactSupported: 'Compact supported',
+ compactUnsupported: 'Compact unsupported',
+ compactUnknown: 'Compact unknown',
+ compactLastChecked: 'Last compact probe',
+ testMode: 'Test mode',
+ testModeDefault: 'Default request',
+ testModeCompact: 'Compact probe',
modelRestrictionDisabledByPassthrough: 'Automatic passthrough is enabled: model whitelist/mapping will not take effect.',
},
anthropic: {
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 3057f93e..b61248ff 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2993,6 +2993,22 @@ export default {
responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
codexCLIOnly: '仅允许 Codex 官方客户端',
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
+ compactMode: 'Compact 模式',
+ compactModeDesc:
+ '控制本账号在 /responses/compact 调度中的参与方式。Auto 跟随探测结果,Force On 强制允许,Force Off 强制排除。',
+ compactModeAuto: '自动',
+ compactModeForceOn: '强制开启',
+ compactModeForceOff: '强制关闭',
+ compactModelMapping: 'Compact 专属模型映射',
+ compactModelMappingDesc:
+ '仅在 /responses/compact 请求中生效。当上游 compact 端点需要特殊 compact 模型时使用。',
+ compactSupported: '支持 Compact',
+ compactUnsupported: '不支持 Compact',
+ compactUnknown: 'Compact 未知',
+ compactLastChecked: '最近探测',
+ testMode: '测试模式',
+ testModeDefault: '常规请求',
+ testModeCompact: 'Compact 探测',
modelRestrictionDisabledByPassthrough: '已开启自动透传:模型白名单/映射不会生效。',
},
anthropic: {
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index e2f41900..50b4353e 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -767,8 +767,8 @@ export interface Account {
platform: AccountPlatform
type: AccountType
credentials?: Record
- // Extra fields including Codex usage and model-level rate limits (Antigravity smart retry)
- extra?: (CodexUsageSnapshot & {
+ // Extra fields including Codex usage, OpenAI compact capability, and model-level rate limits.
+ extra?: (CodexUsageSnapshot & OpenAICompactState & {
model_rate_limits?: Record
antigravity_credits_overages?: Record
} & Record)
@@ -940,6 +940,16 @@ export interface CodexUsageSnapshot {
codex_usage_updated_at?: string // Last update timestamp
}
+export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off'
+
+export interface OpenAICompactState {
+ openai_compact_mode?: OpenAICompactMode
+ openai_compact_supported?: boolean
+ openai_compact_checked_at?: string
+ openai_compact_last_status?: number
+ openai_compact_last_error?: string
+}
+
export interface CreateAccountRequest {
name: string
notes?: string | null
diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue
index 4fec956b..bc4c6215 100644
--- a/frontend/src/views/admin/AccountsView.vue
+++ b/frontend/src/views/admin/AccountsView.vue
@@ -188,6 +188,13 @@
+
+ {{ getOpenAICompactLabel(row) }}
+
| undefined
+ const mode = typeof extra?.openai_compact_mode === 'string' ? extra.openai_compact_mode : 'auto'
+ if (mode === 'force_on') return 'supported'
+ if (mode === 'force_off') return 'unsupported'
+ if (typeof extra?.openai_compact_supported === 'boolean') {
+ return extra.openai_compact_supported ? 'supported' : 'unsupported'
+ }
+ return 'unknown'
+}
+
+function getOpenAICompactLabel(row: any): string | null {
+ switch (getOpenAICompactState(row)) {
+ case 'supported': return t('admin.accounts.openai.compactSupported')
+ case 'unsupported': return t('admin.accounts.openai.compactUnsupported')
+ case 'unknown': return t('admin.accounts.openai.compactUnknown')
+ default: return null
+ }
+}
+
+function getOpenAICompactClass(row: any): string {
+ switch (getOpenAICompactState(row)) {
+ case 'supported': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/40 dark:text-emerald-300'
+ case 'unsupported': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/40 dark:text-rose-300'
+ case 'unknown': return 'bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300'
+ default: return ''
+ }
+}
+
+function getOpenAICompactTitle(row: any): string {
+ const extra = row.extra as Record | undefined
+ const checkedAt = typeof extra?.openai_compact_checked_at === 'string' ? extra.openai_compact_checked_at : ''
+ if (!checkedAt) return getOpenAICompactLabel(row) || ''
+ return `${getOpenAICompactLabel(row)} | ${t('admin.accounts.openai.compactLastChecked')}: ${formatDateTime(new Date(checkedAt))}`
+}
+
function getAntigravityTierClass(row: any): string {
const tier = getAntigravityTierFromRow(row)
switch (tier) {
From 5b63a9b02d7ddc99fa31fed727361c8e97922de2 Mon Sep 17 00:00:00 2001
From: AyeSt0
Date: Sat, 25 Apr 2026 15:09:40 +0800
Subject: [PATCH 27/33] fix(openai): fail over before responses stream output
---
.../service/openai_gateway_service.go | 255 ++++++++++++++++--
.../service/openai_gateway_service_test.go | 191 ++++++++++++-
2 files changed, 428 insertions(+), 18 deletions(-)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 5db273b4..75a92f6e 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -3147,6 +3147,113 @@ type openaiStreamingResultPassthrough struct {
firstTokenMs *int
}
+func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
+ if localStarted {
+ return true
+ }
+ return c != nil && c.Writer != nil && c.Writer.Written()
+}
+
+func openAIStreamEventIsPreamble(eventType string) bool {
+ switch strings.TrimSpace(eventType) {
+ case "response.created", "response.in_progress":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIStreamDataStartsClientOutput(data, eventType string) bool {
+ trimmed := strings.TrimSpace(data)
+ if trimmed == "" {
+ return false
+ }
+ if strings.TrimSpace(eventType) == "response.failed" {
+ return false
+ }
+ return !openAIStreamEventIsPreamble(eventType)
+}
+
+func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
+ code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
+ if code == "" {
+ code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
+ }
+ errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String()))
+ if errType == "" {
+ errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String()))
+ }
+ combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
+ if combined == "" {
+ return true
+ }
+ nonRetryableMarkers := []string{
+ "invalid_request",
+ "content_policy",
+ "policy",
+ "safety",
+ "high-risk cyber",
+ "not allowed",
+ "violat",
+ }
+ for _, marker := range nonRetryableMarkers {
+ if strings.Contains(combined, marker) {
+ return false
+ }
+ }
+ return true
+}
+
+func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
+ c *gin.Context,
+ account *Account,
+ passthrough bool,
+ upstreamRequestID string,
+ payload []byte,
+ message string,
+) *UpstreamFailoverError {
+ message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
+ if message == "" {
+ message = "OpenAI stream disconnected before completion"
+ }
+ detail := ""
+ if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
+ maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
+ if maxBytes <= 0 {
+ maxBytes = 2048
+ }
+ detail = truncateString(string(payload), maxBytes)
+ }
+ if c != nil {
+ setOpsUpstreamError(c, http.StatusBadGateway, message, detail)
+ event := OpsUpstreamErrorEvent{
+ Platform: PlatformOpenAI,
+ UpstreamStatusCode: http.StatusBadGateway,
+ UpstreamRequestID: strings.TrimSpace(upstreamRequestID),
+ Passthrough: passthrough,
+ Kind: "failover",
+ Message: message,
+ Detail: detail,
+ }
+ if account != nil {
+ event.Platform = account.Platform
+ event.AccountID = account.ID
+ event.AccountName = account.Name
+ }
+ appendOpsUpstreamError(c, event)
+ }
+ body, _ := json.Marshal(gin.H{
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": message,
+ },
+ })
+ return &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: body,
+ }
+}
+
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
@@ -3178,7 +3285,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
+ sawFailedEvent := false
+ failedMessage := ""
+ clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
+ pendingLines := make([]string, 0, 8)
+ writePendingLines := func() bool {
+ for _, pending := range pendingLines {
+ if _, err := fmt.Fprintln(w, pending); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
+ return false
+ }
+ }
+ pendingLines = pendingLines[:0]
+ return true
+ }
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -3193,6 +3315,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
+ lineStartsClientOutput := false
+ forceFlushFailedEvent := false
if data, ok := extractOpenAISSEDataLine(line); ok {
dataBytes := []byte(data)
trimmedData := strings.TrimSpace(data)
@@ -3203,13 +3327,24 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
trimmedData = strings.TrimSpace(replacedData)
}
}
+ eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
+ if eventType == "response.failed" {
+ failedMessage = extractOpenAISSEErrorMessage(dataBytes)
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
+ }
+ forceFlushFailedEvent = true
+ sawFailedEvent = true
+ }
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
- if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
+ lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
+ if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -3217,20 +3352,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if !clientDisconnected {
+ if !clientOutputStarted && !lineStartsClientOutput {
+ pendingLines = append(pendingLines, line)
+ continue
+ }
+ if !clientOutputStarted && len(pendingLines) > 0 {
+ if !writePendingLines() {
+ continue
+ }
+ }
if _, err := fmt.Fprintln(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else {
+ clientOutputStarted = true
flusher.Flush()
}
}
}
if err := scanner.Err(); err != nil {
- if sawTerminalEvent {
+ if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
- if clientDisconnected {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+ if sawFailedEvent {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
@@ -3239,6 +3384,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ msg := "OpenAI stream disconnected before completion"
+ if errText := strings.TrimSpace(err.Error()); errText != "" {
+ msg += ": " + errText
+ }
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
+ }
+ if clientDisconnected {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+ }
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
@@ -3247,12 +3403,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
+ if sawFailedEvent {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
+ }
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
+ }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
@@ -3854,6 +4017,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
+ sawFailedEvent := false
+ failedMessage := ""
+ clientOutputStarted := false
+ upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
+ var streamFailoverErr error
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
@@ -3870,7 +4038,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if err := flushBuffered(); err != nil {
clientDisconnected = true
+ return
}
+ clientOutputStarted = true
}
needModelReplace := originalModel != mappedModel
@@ -3878,43 +4048,72 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
}
finalizeStream := func() (*openaiStreamingResult, error) {
+ if !sawTerminalEvent {
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ return resultWithUsage(), s.newOpenAIStreamFailoverError(
+ c,
+ account,
+ false,
+ upstreamRequestID,
+ nil,
+ "OpenAI stream ended before a terminal event",
+ )
+ }
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
+ }
+ if sawFailedEvent {
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
+ }
if !clientDisconnected {
+ hadBufferedData := bufferedWriter.Buffered() > 0
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
+ } else if hadBufferedData {
+ clientOutputStarted = true
}
}
- if !sawTerminalEvent {
- return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
- }
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
- if sawTerminalEvent {
+ if sawTerminalEvent && !sawFailedEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
+ if sawFailedEvent {
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
+ }
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
- // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
- if clientDisconnected {
- return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
- }
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
sendErrorEvent("response_too_large")
return resultWithUsage(), scanErr, true
}
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ msg := "OpenAI stream disconnected before completion"
+ if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
+ msg += ": " + errText
+ }
+ return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
+ }
+ // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
+ if clientDisconnected {
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
+ }
sendErrorEvent("stream_read_error")
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
}
processSSELine := func(line string, queueDrained bool) {
+ if streamFailoverErr != nil {
+ return
+ }
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
@@ -3930,18 +4129,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
+ eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
+ forceFlushFailedEvent := false
+ if eventType == "response.failed" {
+ failedMessage = extractOpenAISSEErrorMessage(dataBytes)
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
+ sawFailedEvent = true
+ streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
+ return
+ }
+ forceFlushFailedEvent = true
+ sawFailedEvent = true
+ }
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
dataBytes = correctedData
data = string(correctedData)
line = "data: " + data
+ eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
+ startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
- shouldFlush := queueDrained
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
+ if firstTokenMs == nil && startsClientOutput {
// 保证首个 token 事件尽快出站,避免影响 TTFT。
shouldFlush = true
}
@@ -3955,12 +4168,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ } else {
+ clientOutputStarted = true
}
}
}
// Record first token time
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ if firstTokenMs == nil && startsClientOutput {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -3976,10 +4191,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
- } else if queueDrained {
+ } else if queueDrained && clientOutputStarted {
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ } else {
+ clientOutputStarted = true
}
}
}
@@ -3990,6 +4207,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
processSSELine(scanner.Text(), true)
+ if streamFailoverErr != nil {
+ return resultWithUsage(), streamFailoverErr
+ }
}
if result, err, done := handleScanErr(scanner.Err()); done {
return result, err
@@ -4039,6 +4259,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return result, err
}
processSSELine(ev.line, len(events) == 0)
+ if streamFailoverErr != nil {
+ return resultWithUsage(), streamFailoverErr
+ }
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 8b7945bc..0cf2392d 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -93,6 +93,13 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
+type errReadCloser struct {
+ err error
+}
+
+func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
+func (r errReadCloser) Close() error { return nil }
+
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -1003,6 +1010,150 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
+func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: errReadCloser{err: io.ErrUnexpectedEOF},
+ Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.in_progress",
+ `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.in_progress",
+ `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.False(t, errors.As(err, &failoverErr))
+ require.True(t, c.Writer.Written())
+ require.Contains(t, rec.Body.String(), "response.failed")
+ require.Contains(t, rec.Body.String(), "high-risk cyber activity")
+}
+
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1072,7 +1223,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
- _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1104,7 +1255,7 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
- _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
@@ -1114,6 +1265,42 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
}
}
+func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
From 9d1751ec57957547b0eb5d3bdf758c2ec82a3e37 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Sat, 25 Apr 2026 08:06:21 +0000
Subject: [PATCH 28/33] chore: sync VERSION to 0.1.118 [skip ci]
---
backend/cmd/server/VERSION | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 8b060688..1fcba8fa 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.117
+0.1.118
From 8987e0ba67c0d1c6dfbb86cf61088ee3c2af5df5 Mon Sep 17 00:00:00 2001
From: hungryboy1025
Date: Sat, 25 Apr 2026 16:56:50 +0800
Subject: [PATCH 29/33] fix(openai): tighten responses stream account tests
---
.../internal/service/account_test_service.go | 28 ++++++++++++---
.../account_test_service_openai_test.go | 25 +++++++++++++
backend/internal/service/gateway_service.go | 2 +-
.../service/openai_gateway_service.go | 5 +--
.../service/openai_gateway_service_test.go | 35 +++++++++++++++++++
5 files changed, 87 insertions(+), 8 deletions(-)
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index d78dcd79..07646474 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -1145,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
+ seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
@@ -1163,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
@@ -1180,9 +1187,20 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
- case "response.completed":
+ case "response.completed", "response.done":
+ seenCompleted = true
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
+ case "response.failed":
+ errorMsg := "OpenAI response failed"
+ if responseData, ok := data["response"].(map[string]any); ok {
+ if errData, ok := responseData["error"].(map[string]any); ok {
+ if msg, ok := errData["message"].(string); ok && msg != "" {
+ errorMsg = msg
+ }
+ }
+ }
+ return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 7202799d..56204be3 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -125,6 +125,31 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
+func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, recorder := newTestContext()
+
+ resp := newJSONResponse(http.StatusOK, "")
+ resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
+
+`))
+
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 90,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Contains(t, recorder.Body.String(), "response.completed")
+ require.NotContains(t, recorder.Body.String(), `"success":true`)
+}
+
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 1713e561..ffd66fc7 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -119,7 +119,7 @@ func openAIStreamEventIsTerminal(data string) bool {
return true
}
switch gjson.Get(trimmed, "type").String() {
- case "response.completed", "response.done", "response.failed":
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 75a92f6e..50e00c01 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4372,7 +4372,8 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
return
}
eventType := gjson.GetBytes(data, "type").String()
- if eventType != "response.completed" && eventType != "response.done" {
+ if eventType != "response.completed" && eventType != "response.done" &&
+ eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
return
}
@@ -4519,7 +4520,7 @@ func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
- case "response.completed", "response.done", "response.failed":
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return eventType, []byte(data), true
}
}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 0cf2392d..154b7908 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -1336,6 +1336,41 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
+func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
+ }()
+
+ result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
+ _ = pr.Close()
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.usage)
+ require.Equal(t, 2, result.usage.InputTokens)
+ require.Equal(t, 3, result.usage.OutputTokens)
+ require.Equal(t, 1, result.usage.CacheReadInputTokens)
+}
+
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
From dac6e520911c3d353132a75f802f1de8e5caeb34 Mon Sep 17 00:00:00 2001
From: gaoren002
Date: Sat, 25 Apr 2026 12:11:27 +0000
Subject: [PATCH 30/33] fix(openai): keep responses stream alive during
pre-output failover
---
.../service/openai_gateway_service.go | 16 +++++---
.../service/openai_gateway_service_test.go | 40 +++++++++++++++++++
2 files changed, 51 insertions(+), 5 deletions(-)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 75a92f6e..5034a407 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4008,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
- // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
- lastDataAt := time.Now()
+ // Track downstream writes separately from upstream reads: pre-output failover
+ // can buffer response.created / response.in_progress, so keepalive must be
+ // based on downstream idle time.
+ lastDownstreamWriteAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
@@ -4041,6 +4043,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
}
clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
needModelReplace := originalModel != mappedModel
@@ -4071,6 +4074,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
} else if hadBufferedData {
clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
return resultWithUsage(), nil
@@ -4114,8 +4118,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if streamFailoverErr != nil {
return
}
- lastDataAt = time.Now()
-
// Extract data from SSE line (supports both "data: " and "data:" formats)
if data, ok := extractOpenAISSEDataLine(line); ok {
@@ -4170,6 +4172,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
}
@@ -4197,6 +4200,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
}
@@ -4283,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if clientDisconnected {
continue
}
- if time.Since(lastDataAt) < keepaliveInterval {
+ if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
@@ -4294,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
+ } else {
+ lastDownstreamWriteAt = time.Now()
}
}
}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 0cf2392d..d54b00ab 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -1117,6 +1117,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
require.Empty(t, rec.Body.String())
}
+func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 1,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+ for i := 0; i < 6; i++ {
+ time.Sleep(250 * time.Millisecond)
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+ }
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
+ }()
+
+ result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ _ = pr.Close()
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), ":\n\n")
+ require.Contains(t, rec.Body.String(), "response.completed")
+}
+
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
From 4e1bb2b4453f03e2086f77767634337e61b7c4f6 Mon Sep 17 00:00:00 2001
From: shaw
Date: Sat, 25 Apr 2026 19:14:34 +0800
Subject: [PATCH 31/33] feat(affiliate): add feature toggle and per-user custom
invite settings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 在系统设置「功能开关」中新增邀请返利总开关,默认关闭;
关闭态:菜单隐藏、注册忽略 aff、新充值不返利,但已有 quota 仍可转余额
- 支持管理员为指定用户设置专属邀请码(覆盖随机码,全局唯一)
- 支持管理员为指定用户设置专属返利比例(覆盖全局比例,可单条/批量调整)
- 在系统设置邀请返利卡片内嵌入专属用户管理表格(搜索/编辑/批量/删除),
删除采用项目通用 ConfirmDialog,会同时清除专属比例并把邀请码重置为系统随机码
- /affiliate 用户页新增「我的返利比例」卡片与动态使用说明,让用户直观看到
分享后能拿到多少(同源 resolveRebateRatePercent 计算,与实际充值一致)
- 新增数据库迁移 132 添加 aff_rebate_rate_percent 与 aff_code_custom 列
- 新增 admin 路由组 /api/v1/admin/affiliates/users/* 共 5 个端点
- AffiliateService 改为只依赖 *SettingService,去除冗余的 SettingRepository
- 邀请码格式校验放宽到 [A-Z0-9_-]{4,32},兼容旧 12 位系统码与新自定义码
- 补充单元测试与集成测试覆盖新方法、冲突路径与边界值
---
backend/cmd/server/wire_gen.go | 5 +-
.../handler/admin/affiliate_handler.go | 183 +++++
.../internal/handler/admin/setting_handler.go | 16 +
backend/internal/handler/dto/settings.go | 5 +
backend/internal/handler/handler.go | 1 +
backend/internal/handler/setting_handler.go | 2 +
backend/internal/handler/wire.go | 3 +
backend/internal/repository/affiliate_repo.go | 244 ++++++
.../affiliate_repo_integration_test.go | 215 +++++
backend/internal/server/api_contract_test.go | 2 +
backend/internal/server/routes/admin.go | 18 +
backend/internal/service/affiliate_service.go | 254 ++++--
.../service/affiliate_service_test.go | 132 ++--
backend/internal/service/domain_constants.go | 2 +
backend/internal/service/setting_service.go | 38 +
backend/internal/service/settings_view.go | 4 +
.../132_affiliate_custom_settings.sql | 16 +
frontend/src/api/admin/affiliates.ts | 108 +++
frontend/src/api/admin/index.ts | 7 +-
frontend/src/api/admin/settings.ts | 6 +
frontend/src/components/layout/AppSidebar.vue | 3 +-
frontend/src/i18n/locales/en.ts | 53 +-
frontend/src/i18n/locales/zh.ts | 53 +-
frontend/src/stores/app.ts | 1 +
frontend/src/types/index.ts | 3 +
frontend/src/utils/featureFlags.ts | 5 +
frontend/src/views/admin/SettingsView.vue | 744 +++++++++++++++++-
frontend/src/views/user/AffiliateView.vue | 28 +-
28 files changed, 2010 insertions(+), 141 deletions(-)
create mode 100644 backend/internal/handler/admin/affiliate_handler.go
create mode 100644 backend/migrations/132_affiliate_custom_settings.sql
create mode 100644 frontend/src/api/admin/affiliates.ts
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index d0b1d3af..f767bbea 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -70,7 +70,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
affiliateRepository := repository.NewAffiliateRepository(client, db)
- affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
+ affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
@@ -231,7 +231,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler)
+ affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go
new file mode 100644
index 00000000..97e649ec
--- /dev/null
+++ b/backend/internal/handler/admin/affiliate_handler.go
@@ -0,0 +1,183 @@
+package admin
+
+import (
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AffiliateHandler handles admin affiliate (邀请返利) management:
+// listing users with custom settings, updating per-user invite codes
+// and exclusive rebate rates, and batch operations.
+type AffiliateHandler struct {
+ affiliateService *service.AffiliateService
+ adminService service.AdminService
+}
+
+// NewAffiliateHandler creates a new admin affiliate handler.
+func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
+ return &AffiliateHandler{
+ affiliateService: affiliateService,
+ adminService: adminService,
+ }
+}
+
+// ListUsers returns paginated users with custom affiliate settings.
+// GET /api/v1/admin/affiliates/users
+func (h *AffiliateHandler) ListUsers(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ search := c.Query("search")
+
+ entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
+ Search: search,
+ Page: page,
+ PageSize: pageSize,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, entries, total, page, pageSize)
+}
+
+// UpdateUserSettings updates a user's affiliate settings.
+// PUT /api/v1/admin/affiliates/users/:user_id
+//
+// Both fields are optional and applied independently.
+type UpdateAffiliateUserRequest struct {
+ AffCode *string `json:"aff_code"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ // ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
+ // Used to disambiguate from "field not provided".
+ ClearRebateRate bool `json:"clear_rebate_rate"`
+}
+
+func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+
+ var req UpdateAffiliateUserRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if req.AffCode != nil {
+ if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ if req.ClearRebateRate {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ } else if req.AffRebateRatePercent != nil {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
+// the exclusive rebate rate AND regenerates the invite code as a new system
+// random one. Conceptually this "removes the user from the custom list".
+//
+// Both writes happen in this handler; failure of one leaves the other applied,
+// but the operation is idempotent so the admin can re-run it safely.
+// DELETE /api/v1/admin/affiliates/users/:user_id
+func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
+//
+// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
+// ignored). Otherwise aff_rebate_rate_percent is required and applied to
+// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
+// can't distinguish a missing field from `null`, and a silent clear from a
+// frontend that forgot to include the rate would be a footgun.
+//
+// POST /api/v1/admin/affiliates/users/batch-rate
+type BatchSetRateRequest struct {
+ UserIDs []int64 `json:"user_ids" binding:"required"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ Clear bool `json:"clear"`
+}
+
+func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
+ var req BatchSetRateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if len(req.UserIDs) == 0 {
+ response.BadRequest(c, "user_ids cannot be empty")
+ return
+ }
+ if !req.Clear && req.AffRebateRatePercent == nil {
+ response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
+ return
+ }
+ rate := req.AffRebateRatePercent
+ if req.Clear {
+ rate = nil
+ }
+ if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": len(req.UserIDs)})
+}
+
+// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
+// shared with the frontend's add-custom-user picker.
+type AffiliateUserSummary struct {
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+}
+
+// LookupUsers searches users by email/username for the "add custom user" modal.
+// GET /api/v1/admin/affiliates/users/lookup?q=
+func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
+ keyword := c.Query("q")
+ if keyword == "" {
+ response.Success(c, []AffiliateUserSummary{})
+ return
+ }
+ users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ result := make([]AffiliateUserSummary, len(users))
+ for i, u := range users {
+ result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 2d4dcb5b..40bf1c69 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -242,6 +242,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
@@ -441,6 +443,9 @@ type UpdateSettingsRequest struct {
// Available Channels feature switch (user-facing)
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled *bool `json:"affiliate_enabled"`
}
// UpdateSettings 更新系统设置
@@ -1265,6 +1270,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AvailableChannelsEnabled
}(),
+ AffiliateEnabled: func() bool {
+ if req.AffiliateEnabled != nil {
+ return *req.AffiliateEnabled
+ }
+ return previousSettings.AffiliateEnabled
+ }(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
@@ -1502,6 +1513,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@@ -1870,6 +1883,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
changed = append(changed, "available_channels_enabled")
}
+ if before.AffiliateEnabled != after.AffiliateEnabled {
+ changed = append(changed, "affiliate_enabled")
+ }
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 86074df7..051fab18 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -192,6 +192,9 @@ type SystemSettings struct {
// Available Channels feature switch (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
type DefaultSubscriptionSetting struct {
@@ -244,6 +247,8 @@ type PublicSettings struct {
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index aee9d927..13e3ac88 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -34,6 +34,7 @@ type AdminHandlers struct {
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
Payment *admin.PaymentHandler
+ Affiliate *admin.AffiliateHandler
}
// Handlers contains all HTTP handlers
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 96964de4..22f2aa15 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -75,5 +75,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
})
}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 6d175488..a8725875 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -37,6 +37,7 @@ func ProvideAdminHandlers(
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
+ affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -67,6 +68,7 @@ func ProvideAdminHandlers(
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
Payment: paymentHandler,
+ Affiliate: affiliateHandler,
}
}
@@ -169,6 +171,7 @@ var ProviderSet = wire.NewSet(
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
+ admin.NewAffiliateHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go
index 342ddf4f..e3dd56b8 100644
--- a/backend/internal/repository/affiliate_repo.go
+++ b/backend/internal/repository/affiliate_repo.go
@@ -294,6 +294,8 @@ func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, us
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
+ aff_code_custom,
+ aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
@@ -315,9 +317,12 @@ WHERE user_id = $1`, userID)
var out service.AffiliateSummary
var inviterID sql.NullInt64
+ var rebateRate sql.NullFloat64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
+ &out.AffCodeCustom,
+ &rebateRate,
&inviterID,
&out.AffCount,
&out.AffQuota,
@@ -330,6 +335,10 @@ WHERE user_id = $1`, userID)
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
+ if rebateRate.Valid {
+ v := rebateRate.Float64
+ out.AffRebateRatePercent = &v
+ }
return &out, nil
}
@@ -337,6 +346,8 @@ func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code
rows, err := client.QueryContext(ctx, `
SELECT user_id,
aff_code,
+ aff_code_custom,
+ aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
@@ -360,9 +371,12 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
var out service.AffiliateSummary
var inviterID sql.NullInt64
+ var rebateRate sql.NullFloat64
if err := rows.Scan(
&out.UserID,
&out.AffCode,
+ &out.AffCodeCustom,
+ &rebateRate,
&inviterID,
&out.AffCount,
&out.AffQuota,
@@ -375,6 +389,10 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
if inviterID.Valid {
out.InviterID = &inviterID.Int64
}
+ if rebateRate.Valid {
+ v := rebateRate.Float64
+ out.AffRebateRatePercent = &v
+ }
return &out, nil
}
@@ -418,3 +436,229 @@ func isAffiliateUniqueViolation(err error) bool {
}
return false
}
+
+// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。
+// 唯一性冲突返回 ErrAffiliateCodeTaken。
+func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
+ if userID <= 0 {
+ return service.ErrUserNotFound
+ }
+ code := strings.ToUpper(strings.TrimSpace(newCode))
+ if code == "" {
+ return service.ErrAffiliateCodeInvalid
+ }
+
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_code = $1,
+ aff_code_custom = true,
+ updated_at = NOW()
+WHERE user_id = $2`, code, userID)
+ if err != nil {
+ if isAffiliateUniqueViolation(err) {
+ return service.ErrAffiliateCodeTaken
+ }
+ return fmt.Errorf("update aff_code: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ return nil
+ })
+}
+
+// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。
+func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
+ if userID <= 0 {
+ return "", service.ErrUserNotFound
+ }
+ var newCode string
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ for i := 0; i < affiliateCodeMaxAttempts; i++ {
+ candidate, codeErr := generateAffiliateCode()
+ if codeErr != nil {
+ return codeErr
+ }
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_code = $1,
+ aff_code_custom = false,
+ updated_at = NOW()
+WHERE user_id = $2`, candidate, userID)
+ if err != nil {
+ if isAffiliateUniqueViolation(err) {
+ continue
+ }
+ return fmt.Errorf("reset aff_code: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ newCode = candidate
+ return nil
+ }
+ return fmt.Errorf("reset aff_code: exhausted attempts")
+ })
+ if err != nil {
+ return "", err
+ }
+ return newCode, nil
+}
+
+// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。
+func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
+ if userID <= 0 {
+ return service.ErrUserNotFound
+ }
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ // nullableArg lets us use a single UPDATE for both "set value" and
+ // "clear" cases — database/sql converts nil interface{} to SQL NULL.
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_rebate_rate_percent = $1,
+ updated_at = NOW()
+WHERE user_id = $2`, nullableArg(ratePercent), userID)
+ if err != nil {
+ return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ return nil
+ })
+}
+
+// BatchSetUserRebateRate 批量为多个用户设置专属比例(nil 清除)。
+func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
+ if len(userIDs) == 0 {
+ return nil
+ }
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ for _, uid := range userIDs {
+ if uid <= 0 {
+ continue
+ }
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
+ return err
+ }
+ }
+ _, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_rebate_rate_percent = $1,
+ updated_at = NOW()
+WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
+ if err != nil {
+ return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
+ }
+ return nil
+ })
+}
+
+// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter
+// binding: nil pointer → SQL NULL, non-nil → the float value.
+func nullableArg(v *float64) any {
+ if v == nil {
+ return nil
+ }
+ return *v
+}
+
+// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
+//
+// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
+// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。
+// 这避免了为两种情况维护两份 SQL 模板。
+func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
+ page := filter.Page
+ if page < 1 {
+ page = 1
+ }
+ pageSize := filter.PageSize
+ if pageSize <= 0 || pageSize > 200 {
+ pageSize = 20
+ }
+ offset := (page - 1) * pageSize
+ likePattern := "%" + strings.TrimSpace(filter.Search) + "%"
+
+ const baseFrom = `
+FROM user_affiliates ua
+JOIN users u ON u.id = ua.user_id
+WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
+ AND (u.email ILIKE $1 OR u.username ILIKE $1)`
+
+ client := clientFromContext(ctx, r.client)
+
+ total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
+ if err != nil {
+ return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
+ }
+
+ listQuery := `
+SELECT ua.user_id,
+ COALESCE(u.email, ''),
+ COALESCE(u.username, ''),
+ ua.aff_code,
+ ua.aff_code_custom,
+ ua.aff_rebate_rate_percent,
+ ua.aff_count` + baseFrom + `
+ORDER BY ua.updated_at DESC
+LIMIT $2 OFFSET $3`
+
+ rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
+ if err != nil {
+ return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ entries := make([]service.AffiliateAdminEntry, 0)
+ for rows.Next() {
+ var e service.AffiliateAdminEntry
+ var rebate sql.NullFloat64
+ if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
+ &e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
+ return nil, 0, err
+ }
+ if rebate.Valid {
+ v := rebate.Float64
+ e.AffRebateRatePercent = &v
+ }
+ entries = append(entries, e)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, 0, err
+ }
+ return entries, total, nil
+}
+
+// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
+func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
+ rows, err := client.QueryContext(ctx, query, args...)
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, nil
+ }
+ var v int64
+ if err := rows.Scan(&v); err != nil {
+ return 0, err
+ }
+ return v, nil
+}
diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go
index 3fa84426..369f57cf 100644
--- a/backend/internal/repository/affiliate_repo_integration_test.go
+++ b/backend/internal/repository/affiliate_repo_integration_test.go
@@ -182,3 +182,218 @@ VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
require.InDelta(t, 3.21, persistedBalance, 1e-9)
}
+
+// TestAffiliateRepository_AdminCustomCode covers the success path of admin
+// invite-code rewrite + reset within a shared test transaction:
+// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
+// - the old code can no longer be found
+// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
+//
+// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
+// test because a unique-violation aborts the surrounding Postgres tx, which
+// would poison subsequent assertions in the same transaction.
+func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
+ originalCode := original.AffCode
+
+ // Rewrite to a custom code
+ customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
+
+ updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, customCode, updated.AffCode)
+ require.True(t, updated.AffCodeCustom)
+
+ // Lookup by new custom code finds the user
+ byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
+ require.NoError(t, err)
+ require.Equal(t, u.ID, byCode.UserID)
+
+ // Old system code should no longer match
+ _, err = repo.GetAffiliateByCode(txCtx, originalCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+
+ // Reset back to a fresh system code, clears custom flag
+ newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
+ require.NoError(t, err)
+ require.NotEqual(t, customCode, newSysCode)
+
+ reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, newSysCode, reset.AffCode)
+ require.False(t, reset.AffCodeCustom)
+
+ // The old custom code is now free again
+ _, err = repo.GetAffiliateByCode(txCtx, customCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+}
+
+// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
+// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
+// this test must be the only assertion and run in its own tx — production
+// callers each have their own outer tx, so this matches real behavior.
+func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ taker := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ requester := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+
+ takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
+
+ // Now requester tries to grab the same code → conflict.
+ err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
+ require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
+}
+
+// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
+// set/clear and the Batch variant including NULL semantics.
+func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u1 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ u2 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ // Set exclusive rate for u1
+ rate := 42.5
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
+
+ got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.NotNil(t, got.AffRebateRatePercent)
+ require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
+
+ // Clear exclusive rate
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
+ cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.Nil(t, cleared.AffRebateRatePercent)
+
+ // Batch set both users
+ batchRate := 15.0
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
+
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.NotNil(t, v.AffRebateRatePercent)
+ require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
+ }
+
+ // Batch clear
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.Nil(t, v.AffRebateRatePercent)
+ }
+}
+
+// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
+// only includes users with at least one override applied.
+func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ // User without any custom config — should NOT appear in the list.
+ plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
+ uPlain := mustCreateUser(t, client, &service.User{
+ Email: plainEmail, PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ _, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
+ require.NoError(t, err)
+
+ // User with a custom code — should appear.
+ uCode := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
+
+ // User with only an exclusive rate — should appear.
+ uRate := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ r := 33.3
+ require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
+
+ entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
+ Page: 1, PageSize: 100,
+ })
+ require.NoError(t, err)
+
+ // Build a quick lookup to assert per-user attributes (other tests may have
+ // inserted custom rows in the same DB; we only care about our 3).
+ byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
+ for _, e := range entries {
+ byUserID[e.UserID] = e
+ }
+
+ require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
+
+ codeEntry, ok := byUserID[uCode.ID]
+ require.True(t, ok, "custom-code user missing from list")
+ require.True(t, codeEntry.AffCodeCustom)
+ require.Nil(t, codeEntry.AffRebateRatePercent)
+
+ rateEntry, ok := byUserID[uRate.ID]
+ require.True(t, ok, "custom-rate user missing from list")
+ require.False(t, rateEntry.AffCodeCustom)
+ require.NotNil(t, rateEntry.AffRebateRatePercent)
+ require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
+
+ require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
+}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 35a6524a..39286cbf 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -775,6 +775,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
+ "affiliate_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
@@ -951,6 +952,7 @@ func TestAPIContracts(t *testing.T) {
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
+ "affiliate_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 70160f7e..1c786f50 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -91,6 +91,9 @@ func RegisterAdminRoutes(
// 渠道监控
registerChannelMonitorRoutes(admin, h)
+
+ // 邀请返利(专属用户管理)
+ registerAffiliateRoutes(admin, h)
}
}
@@ -594,3 +597,18 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
}
+
+// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
+func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ affiliates := admin.Group("/affiliates")
+ {
+ users := affiliates.Group("/users")
+ {
+ users.GET("", h.Admin.Affiliate.ListUsers)
+ users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
+ users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
+ users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
+ users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
+ }
+ }
+}
diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go
index fa8e2018..560b71ab 100644
--- a/backend/internal/service/affiliate_service.go
+++ b/backend/internal/service/affiliate_service.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"math"
- "strconv"
"strings"
"time"
@@ -15,28 +14,39 @@ import (
var (
ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
+ ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
)
const (
affiliateInviteesLimit = 100
- // affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength.
- affiliateCodeFormatLength = 12
+ // AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
+ // 12-char codes and admin-customized codes (e.g. "VIP2026").
+ AffiliateCodeMinLength = 4
+ AffiliateCodeMaxLength = 32
)
-// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used
-// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9).
+// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
+// All input passes through strings.ToUpper before validation, so lowercase from
+// users is normalized — admins may supply mixed case in their UI.
var affiliateCodeValidChar = func() [256]bool {
var tbl [256]bool
- for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") {
+ for c := byte('A'); c <= 'Z'; c++ {
tbl[c] = true
}
+ for c := byte('0'); c <= '9'; c++ {
+ tbl[c] = true
+ }
+ tbl['_'] = true
+ tbl['-'] = true
return tbl
}()
+// isValidAffiliateCodeFormat validates code format for both binding (user input)
+// and admin updates. Caller is expected to upper-case the input first.
func isValidAffiliateCodeFormat(code string) bool {
- if len(code) != affiliateCodeFormatLength {
+ if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
return false
}
for i := 0; i < len(code); i++ {
@@ -48,14 +58,16 @@ func isValidAffiliateCodeFormat(code string) bool {
}
type AffiliateSummary struct {
- UserID int64 `json:"user_id"`
- AffCode string `json:"aff_code"`
- InviterID *int64 `json:"inviter_id,omitempty"`
- AffCount int `json:"aff_count"`
- AffQuota float64 `json:"aff_quota"`
- AffHistoryQuota float64 `json:"aff_history_quota"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
}
type AffiliateInvitee struct {
@@ -72,7 +84,11 @@ type AffiliateDetail struct {
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
- Invitees []AffiliateInvitee `json:"invitees"`
+ // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
+ // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
+ // 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
+ EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
+ Invitees []AffiliateInvitee `json:"invitees"`
}
type AffiliateRepository interface {
@@ -82,24 +98,57 @@ type AffiliateRepository interface {
AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
+
+ // 管理端:用户级专属配置
+ UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
+ ResetUserAffCode(ctx context.Context, userID int64) (string, error)
+ SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
+ BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
+ ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
+}
+
+// AffiliateAdminFilter 列表筛选条件
+type AffiliateAdminFilter struct {
+ Search string
+ Page int
+ PageSize int
+}
+
+// AffiliateAdminEntry 专属用户列表条目
+type AffiliateAdminEntry struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ AffCount int `json:"aff_count"`
}
type AffiliateService struct {
repo AffiliateRepository
- settingRepo SettingRepository
+ settingService *SettingService
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCacheService *BillingCacheService
}
-func NewAffiliateService(repo AffiliateRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
+func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
return &AffiliateService{
repo: repo,
- settingRepo: settingRepo,
+ settingService: settingService,
authCacheInvalidator: authCacheInvalidator,
billingCacheService: billingCacheService,
}
}
+// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
+func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
+ if s == nil || s.settingService == nil {
+ return AffiliateEnabledDefault
+ }
+ return s.settingService.IsAffiliateEnabled(ctx)
+}
+
func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
@@ -120,13 +169,14 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
return nil, err
}
return &AffiliateDetail{
- UserID: summary.UserID,
- AffCode: summary.AffCode,
- InviterID: summary.InviterID,
- AffCount: summary.AffCount,
- AffQuota: summary.AffQuota,
- AffHistoryQuota: summary.AffHistoryQuota,
- Invitees: invitees,
+ UserID: summary.UserID,
+ AffCode: summary.AffCode,
+ InviterID: summary.InviterID,
+ AffCount: summary.AffCount,
+ AffQuota: summary.AffQuota,
+ AffHistoryQuota: summary.AffHistoryQuota,
+ EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
+ Invitees: invitees,
}, nil
}
@@ -135,12 +185,16 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
if code == "" {
return nil
}
- if !isValidAffiliateCodeFormat(code) {
- return ErrAffiliateCodeInvalid
- }
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
+ // 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
+ if !s.IsEnabled(ctx) {
+ return nil
+ }
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
if err != nil {
@@ -178,6 +232,10 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
return 0, nil
}
+ // 总开关关闭时,新充值不再产生返利
+ if !s.IsEnabled(ctx) {
+ return 0, nil
+ }
inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
if err != nil {
@@ -187,16 +245,17 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
return 0, nil
}
- rebateRatePercent := s.loadAffiliateRebateRatePercent(ctx)
+ // 加载邀请人 profile,优先使用专属比例(覆盖全局)
+ inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
+ if err != nil {
+ return 0, err
+ }
+ rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
- if _, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID); err != nil {
- return 0, err
- }
-
applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
if err != nil {
return 0, err
@@ -207,6 +266,28 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
return rebate, nil
}
+// resolveRebateRatePercent returns the inviter's exclusive rate when set,
+// otherwise the global setting value (clamped to [Min, Max]).
+func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
+ if inviter != nil && inviter.AffRebateRatePercent != nil {
+ v := *inviter.AffRebateRatePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return s.globalRebateRatePercent(ctx)
+ }
+ return clampAffiliateRebateRate(v)
+ }
+ return s.globalRebateRatePercent(ctx)
+}
+
+// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
+// returning the documented default when SettingService is unavailable.
+func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
+ if s == nil || s.settingService == nil {
+ return AffiliateRebateRateDefault
+ }
+ return s.settingService.GetAffiliateRebateRatePercent(ctx)
+}
+
func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
if s == nil || s.repo == nil {
return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
@@ -236,32 +317,6 @@ func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([
return invitees, nil
}
-func (s *AffiliateService) loadAffiliateRebateRatePercent(ctx context.Context) float64 {
- if s == nil || s.settingRepo == nil {
- return AffiliateRebateRateDefault
- }
-
- raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
- if err != nil {
- return AffiliateRebateRateDefault
- }
-
- rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
- if err != nil {
- return AffiliateRebateRateDefault
- }
- if math.IsNaN(rate) || math.IsInf(rate, 0) {
- return AffiliateRebateRateDefault
- }
- if rate < AffiliateRebateRateMin {
- return AffiliateRebateRateMin
- }
- if rate > AffiliateRebateRateMax {
- return AffiliateRebateRateMax
- }
- return rate
-}
-
func roundTo(v float64, scale int) float64 {
factor := math.Pow10(scale)
return math.Round(v*factor) / factor
@@ -312,3 +367,82 @@ func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID
}
}
}
+
+// =========================
+// Admin: 专属配置管理
+// =========================
+
+// validateExclusiveRate ensures a per-user override is finite and within
+// [Min, Max]. nil is always valid (means "clear / fall back to global").
+func validateExclusiveRate(ratePercent *float64) error {
+ if ratePercent == nil {
+ return nil
+ }
+ v := *ratePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
+ }
+ if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
+ return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
+ }
+ return nil
+}
+
+// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
+func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
+ return s.repo.UpdateUserAffCode(ctx, userID, code)
+}
+
+// AdminResetUserAffCode 重置用户邀请码为系统随机码。
+func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
+ if s == nil || s.repo == nil {
+ return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ResetUserAffCode(ctx, userID)
+}
+
+// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
+func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
+}
+
+// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
+func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ cleaned := make([]int64, 0, len(userIDs))
+ for _, uid := range userIDs {
+ if uid > 0 {
+ cleaned = append(cleaned, uid)
+ }
+ }
+ if len(cleaned) == 0 {
+ return nil
+ }
+ return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
+}
+
+// AdminListCustomUsers 列出有专属配置的用户。
+func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
+ if s == nil || s.repo == nil {
+ return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ListUsersWithCustomSettings(ctx, filter)
+}
diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go
index 605fe00f..c02a4dd7 100644
--- a/backend/internal/service/affiliate_service_test.go
+++ b/backend/internal/service/affiliate_service_test.go
@@ -4,51 +4,82 @@ package service
import (
"context"
+ "math"
"testing"
"github.com/stretchr/testify/require"
)
-type affiliateSettingRepoStub struct {
- value string
- err error
-}
-
-func (s *affiliateSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, s.err }
-func (s *affiliateSettingRepoStub) GetValue(context.Context, string) (string, error) {
- if s.err != nil {
- return "", s.err
- }
- return s.value, nil
-}
-func (s *affiliateSettingRepoStub) Set(context.Context, string, string) error { return s.err }
-func (s *affiliateSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
- if s.err != nil {
- return nil, s.err
- }
- return map[string]string{}, nil
-}
-func (s *affiliateSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
- return s.err
-}
-func (s *affiliateSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
- if s.err != nil {
- return nil, s.err
- }
- return map[string]string{}, nil
-}
-func (s *affiliateSettingRepoStub) Delete(context.Context, string) error { return s.err }
-
-func TestAffiliateRebateRatePercentSemantics(t *testing.T) {
+// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
+// AffRebateRatePercent overrides the global rate, that NULL falls back to the
+// global rate, and that out-of-range exclusive rates are clamped silently.
+//
+// SettingService is left nil here so globalRebateRatePercent returns the
+// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
+// fallback path without spinning up a settings stub.
+func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
t.Parallel()
+ svc := &AffiliateService{}
- svc := &AffiliateService{settingRepo: &affiliateSettingRepoStub{value: "1"}}
- rate := svc.loadAffiliateRebateRatePercent(context.Background())
- require.Equal(t, 1.0, rate)
+ // nil exclusive rate → falls back to global default (20%)
+ require.InDelta(t, AffiliateRebateRateDefault,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
- svc.settingRepo = &affiliateSettingRepoStub{value: "0.2"}
- rate = svc.loadAffiliateRebateRatePercent(context.Background())
- require.Equal(t, 0.2, rate)
+ // exclusive rate set → overrides global
+ rate := 50.0
+ require.InDelta(t, 50.0,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
+
+ // exclusive rate 0 → returns 0 (no rebate, intentional)
+ zero := 0.0
+ require.InDelta(t, 0.0,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
+
+ // exclusive rate above max → clamped to Max
+ tooHigh := 250.0
+ require.InDelta(t, AffiliateRebateRateMax,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
+
+ // exclusive rate below min → clamped to Min
+ tooLow := -5.0
+ require.InDelta(t, AffiliateRebateRateMin,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
+}
+
+// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
+// safely handles a nil settingService dependency by returning the default
+// (off). This protects callers from nil-pointer crashes in misconfigured
+// environments.
+func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
+ t.Parallel()
+ svc := &AffiliateService{}
+ require.False(t, svc.IsEnabled(context.Background()))
+ require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
+}
+
+// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
+// admin-facing rate setters: nil is always valid (clear), in-range values
+// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
+func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
+ t.Parallel()
+ require.NoError(t, validateExclusiveRate(nil))
+
+ for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
+ v := v
+ require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
+ }
+
+ for _, v := range []float64{-0.01, 100.01, -100, 200} {
+ v := v
+ require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
+ }
+
+ nan := math.NaN()
+ require.Error(t, validateExclusiveRate(&nan))
+ posInf := math.Inf(1)
+ require.Error(t, validateExclusiveRate(&posInf))
+ negInf := math.Inf(-1)
+ require.Error(t, validateExclusiveRate(&negInf))
}
func TestMaskEmail(t *testing.T) {
@@ -61,24 +92,33 @@ func TestMaskEmail(t *testing.T) {
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
+ // 邀请码格式校验同时服务于:
+ // 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
+ // 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
+ // 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
cases := []struct {
name string
in string
want bool
}{
- {"valid canonical", "ABCDEFGHJKLM", true},
+ {"valid canonical 12-char", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
- {"too short", "ABCDEFGHJKL", false},
- {"too long", "ABCDEFGHJKLMN", false},
- {"contains excluded letter I", "IBCDEFGHJKLM", false},
- {"contains excluded letter O", "OBCDEFGHJKLM", false},
- {"contains excluded digit 0", "0BCDEFGHJKLM", false},
- {"contains excluded digit 1", "1BCDEFGHJKLM", false},
+ {"valid admin custom short", "VIP1", true},
+ {"valid admin custom with hyphen", "NEW-USER", true},
+ {"valid admin custom with underscore", "VIP_2026", true},
+ {"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
+ // Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
+ {"letter I now allowed", "IBCDEFGHJKLM", true},
+ {"letter O now allowed", "OBCDEFGHJKLM", true},
+ {"digit 0 now allowed", "0BCDEFGHJKLM", true},
+ {"digit 1 now allowed", "1BCDEFGHJKLM", true},
+ {"too short (3 chars)", "ABC", false},
+ {"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
- {"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset
- {"ascii punctuation", "ABCDEFGHJK.M", false},
+ {"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
+ {"ascii punctuation .", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 23afeb87..04037987 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -23,6 +23,7 @@ const (
AffiliateRebateRateDefault = 20.0
AffiliateRebateRateMin = 0.0
AffiliateRebateRateMax = 100.0
+ AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
)
// Platform constants
@@ -94,6 +95,7 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
+ SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
// 邮件服务设置
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index f3801c48..f871ee85 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -454,6 +454,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyChannelMonitorEnabled,
SettingKeyChannelMonitorDefaultIntervalSeconds,
SettingKeyAvailableChannelsEnabled,
+ SettingKeyAffiliateEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -541,6 +542,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
+
+ AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
}, nil
}
@@ -687,6 +690,7 @@ type PublicSettingsInjectionPayload struct {
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
@@ -739,6 +743,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+ AffiliateEnabled: settings.AffiliateEnabled,
}, nil
}
@@ -1205,6 +1210,9 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
// Available channels feature switch
updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled)
+ // Affiliate (邀请返利) feature switch
+ updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
+
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -1480,6 +1488,30 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
+// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
+func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
+ if err != nil {
+ return false // 默认关闭
+ }
+ return value == "true"
+}
+
+// GetAffiliateRebateRatePercent 读取并 clamp 全局返利比例。
+// 解析失败、缺失或越界都回退到 AffiliateRebateRateDefault — 该比例从不抛错,
+// 调用方只关心一个可用的数值。
+func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
+ if err != nil {
+ return AffiliateRebateRateDefault
+ }
+ rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
+ if err != nil || math.IsNaN(rate) || math.IsInf(rate, 0) {
+ return AffiliateRebateRateDefault
+ }
+ return clampAffiliateRebateRate(rate)
+}
+
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -1771,6 +1803,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Available channels feature (default disabled; opt-in)
SettingKeyAvailableChannelsEnabled: "false",
+ // Affiliate (邀请返利) feature (default disabled; opt-in)
+ SettingKeyAffiliateEnabled: "false",
+
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
@@ -2091,6 +2126,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Available channels feature (default: disabled; strict true)
result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"
+ // Affiliate (邀请返利) feature (default: disabled; strict true)
+ result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
+
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 8a3bd421..70d8efc3 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -106,6 +106,7 @@ type SystemSettings struct {
DefaultConcurrency int
DefaultBalance float64
+ AffiliateEnabled bool
AffiliateRebateRate float64
DefaultUserRPMLimit int
DefaultSubscriptions []DefaultSubscriptionSetting
@@ -225,6 +226,9 @@ type PublicSettings struct {
// Available Channels feature (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature toggle
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
type WeChatConnectOAuthConfig struct {
diff --git a/backend/migrations/132_affiliate_custom_settings.sql b/backend/migrations/132_affiliate_custom_settings.sql
new file mode 100644
index 00000000..840fe8e0
--- /dev/null
+++ b/backend/migrations/132_affiliate_custom_settings.sql
@@ -0,0 +1,16 @@
+-- 邀请返利:用户专属配置增强
+-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例(百分比,NULL 表示沿用全局比例)
+-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选)
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2);
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false;
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings
+ ON user_affiliates (updated_at)
+ WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL;
+
+COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100,NULL 表示沿用全局)';
+COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)';
diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts
new file mode 100644
index 00000000..22639bd2
--- /dev/null
+++ b/frontend/src/api/admin/affiliates.ts
@@ -0,0 +1,108 @@
+/**
+ * Admin Affiliate API endpoints
+ * Manage per-user affiliate (邀请返利) configurations:
+ * exclusive invite codes (overrides aff_code) and exclusive rebate rates.
+ */
+
+import { apiClient } from '../client'
+import type { PaginatedResponse } from '@/types'
+
+export interface AffiliateAdminEntry {
+ user_id: number
+ email: string
+ username: string
+ aff_code: string
+ aff_code_custom: boolean
+ aff_rebate_rate_percent?: number | null
+ aff_count: number
+}
+
+export interface ListAffiliateUsersParams {
+ page?: number
+ page_size?: number
+ search?: string
+}
+
+export interface UpdateAffiliateUserRequest {
+ aff_code?: string
+ aff_rebate_rate_percent?: number | null
+ /** Set true to explicitly clear the per-user rate (sets it to NULL). */
+ clear_rebate_rate?: boolean
+}
+
+export interface BatchSetRateRequest {
+ user_ids: number[]
+ aff_rebate_rate_percent?: number | null
+ /** Set true to clear rates instead of setting. */
+ clear?: boolean
+}
+
+export interface SimpleUser {
+ id: number
+ email: string
+ username: string
+}
+
+export async function listUsers(
+ params: ListAffiliateUsersParams = {},
+): Promise> {
+ const { data } = await apiClient.get>(
+ '/admin/affiliates/users',
+ {
+ params: {
+ page: params.page ?? 1,
+ page_size: params.page_size ?? 20,
+ search: params.search ?? '',
+ },
+ },
+ )
+ return data
+}
+
+export async function lookupUsers(q: string): Promise {
+ const { data } = await apiClient.get(
+ '/admin/affiliates/users/lookup',
+ { params: { q } },
+ )
+ return data
+}
+
+export async function updateUserSettings(
+ userId: number,
+ payload: UpdateAffiliateUserRequest,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.put<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ payload,
+ )
+ return data
+}
+
+export async function clearUserSettings(
+ userId: number,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.delete<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ )
+ return data
+}
+
+export async function batchSetRate(
+ payload: BatchSetRateRequest,
+): Promise<{ affected: number }> {
+ const { data } = await apiClient.post<{ affected: number }>(
+ '/admin/affiliates/users/batch-rate',
+ payload,
+ )
+ return data
+}
+
+export const affiliatesAPI = {
+ listUsers,
+ lookupUsers,
+ updateUserSettings,
+ clearUserSettings,
+ batchSetRate,
+}
+
+export default affiliatesAPI
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 9cda5814..80241794 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -29,6 +29,7 @@ import channelsAPI from './channels'
import channelMonitorAPI from './channelMonitor'
import channelMonitorTemplateAPI from './channelMonitorTemplate'
import adminPaymentAPI from './payment'
+import affiliatesAPI from './affiliates'
/**
* Unified admin API object for convenient access
@@ -59,7 +60,8 @@ export const adminAPI = {
channels: channelsAPI,
channelMonitor: channelMonitorAPI,
channelMonitorTemplate: channelMonitorTemplateAPI,
- payment: adminPaymentAPI
+ payment: adminPaymentAPI,
+ affiliates: affiliatesAPI
}
export {
@@ -88,7 +90,8 @@ export {
channelsAPI,
channelMonitorAPI,
channelMonitorTemplateAPI,
- adminPaymentAPI
+ adminPaymentAPI,
+ affiliatesAPI
}
export default adminAPI
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 971c2314..0d98c9e9 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -478,6 +478,9 @@ export interface SystemSettings {
// Available Channels feature switch
available_channels_enabled: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled: boolean;
}
export interface UpdateSettingsRequest {
@@ -636,6 +639,9 @@ export interface UpdateSettingsRequest {
// Available Channels feature switch
available_channels_enabled?: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled?: boolean;
}
/**
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index a3a8c30e..d8e2794e 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -634,6 +634,7 @@ const ChevronDownIcon = {
const flagChannelMonitor = makeSidebarFlag(FeatureFlags.channelMonitor)
const flagPayment = makeSidebarFlag(FeatureFlags.payment)
const flagAvailableChannels = makeSidebarFlag(FeatureFlags.availableChannels)
+const flagAffiliate = makeSidebarFlag(FeatureFlags.affiliate)
const flagOpsMonitoring = () => adminSettingsStore.opsMonitoringEnabled
const flagAdminPayment = () => adminSettingsStore.paymentEnabled
@@ -656,7 +657,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] {
{ path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
- { path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true },
+ { path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true, featureFlag: flagAffiliate },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
path: `/custom/${item.id}`,
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 5aa63e6a..42d68b70 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -985,6 +985,8 @@ export default {
loadFailed: 'Failed to load affiliate data',
transferFailed: 'Failed to transfer affiliate quota',
stats: {
+ rebateRate: 'My Rebate Rate',
+ rebateRateHint: 'What you earn each time an invitee recharges',
invitedUsers: 'Invited Users',
availableQuota: 'Available Rebate Quota',
totalQuota: 'Historical Rebate Quota'
@@ -1009,7 +1011,7 @@ export default {
tips: {
title: 'How It Works',
line1: 'Share your affiliate code or invite link with new users.',
- line2: 'When invitees recharge, you receive rebate quota based on the configured rate.',
+ line2: 'When invitees recharge, you receive {rate} of the recharge as rebate quota.',
line3: 'Transfer rebate quota to balance at any time.'
}
},
@@ -4779,6 +4781,55 @@ export default {
enabled: 'Enable Available Channels',
enabledHint: 'When off, the sidebar entry is hidden and the endpoint returns an empty list.',
},
+ affiliate: {
+ title: 'Affiliate (Invite Rebate)',
+ description: 'Existing users invite new ones; the inviter earns a percentage rebate on the invitee’s recharges. Disabled by default.',
+ enabled: 'Enable Affiliate',
+ enabledHint: 'When off, the affiliate menu is hidden, the aff parameter is ignored at signup, and new recharges generate no rebate. Existing rebate balances can still be transferred.',
+ rebateRate: 'Global Rebate Rate',
+ rebateRateHint: 'Default percentage given back to the inviter on recharges (0-100, e.g. 10 = 10%).',
+ customUsers: {
+ title: 'Per-User Overrides',
+ description: 'Set a custom invite code or exclusive rebate rate for specific users. Lists only users that have an override applied.',
+ addButton: 'Add Custom User',
+ searchPlaceholder: 'Search by email or username',
+ batchButton: 'Batch Set Rate ({count} selected)',
+ empty: 'No users with custom affiliate settings yet',
+ customBadge: 'custom',
+ useGlobal: 'use global',
+ resetTitle: 'Reset Custom Settings',
+ resetMessage: 'Reset all custom settings for {email}?\n• The exclusive rebate rate will be cleared (fall back to the global rate)\n• The invite code will be regenerated as a new system code (previously shared links will stop working)',
+ totalLabel: '{total} total',
+ col: {
+ email: 'Email',
+ username: 'Username',
+ code: 'Invite Code',
+ rate: 'Custom Rate',
+ actions: 'Actions',
+ },
+ },
+ modal: {
+ addTitle: 'Add Custom User',
+ editTitle: 'Edit Custom Settings',
+ userLabel: 'User',
+ userPlaceholder: 'Search by email or username',
+ changeUser: 'Change user',
+ codeLabel: 'Custom Invite Code (optional)',
+ codePlaceholder: 'e.g. VIP2026',
+ codeHint: '4-32 characters; A-Z, 0-9, underscore, dash. Leave empty to keep current. Input is upper-cased.',
+ rateLabel: 'Exclusive Rebate Rate (optional)',
+ ratePlaceholder: 'e.g. 30',
+ rateHint: '0-100. Leave empty (in edit mode) to clear and fall back to the global rate.',
+ errorBadRate: 'Please enter a number between 0 and 100',
+ errorEmpty: 'Fill at least one: custom invite code or exclusive rebate rate',
+ },
+ batchModal: {
+ title: 'Batch Set Rate ({count} users selected)',
+ hint: 'Apply the same exclusive rebate rate to all selected users.',
+ placeholder: 'e.g. 30',
+ clearHint: 'Submitting empty will clear the exclusive rate for selected users.',
+ },
+ },
},
emailTabDisabledTitle: 'Email Verification Not Enabled',
emailTabDisabledHint: 'Enable email verification in the Security tab to configure SMTP settings.',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index b61248ff..7601d01c 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -989,6 +989,8 @@ export default {
loadFailed: '加载邀请返利数据失败',
transferFailed: '转入余额失败',
stats: {
+ rebateRate: '我的返利比例',
+ rebateRateHint: '被邀请用户每次充值后你可获得的返利比例',
invitedUsers: '邀请人数',
availableQuota: '可转返利额度',
totalQuota: '历史返利额度'
@@ -1013,7 +1015,7 @@ export default {
tips: {
title: '使用说明',
line1: '将邀请码或邀请链接分享给新用户。',
- line2: '被邀请用户充值后,你可获得对应比例的返利额度。',
+ line2: '被邀请用户充值后,你可获得 {rate} 的返利额度。',
line3: '返利额度可随时转入账户余额。'
}
},
@@ -4942,6 +4944,55 @@ export default {
enabled: '启用可用渠道',
enabledHint: '关闭后用户端侧边栏入口隐藏,接口返回空数组。',
},
+ affiliate: {
+ title: '邀请返利',
+ description: '老用户邀请新用户注册,新用户充值后老用户按比例获得返利额度。默认关闭。',
+ enabled: '启用邀请返利',
+ enabledHint: '关闭后用户菜单中的邀请页面入口隐藏、注册时忽略邀请码、新充值不再产生返利。已有返利额度仍可转入余额。',
+ rebateRate: '全局返利比例',
+ rebateRateHint: '充值后返给邀请人的默认比例(0-100%,例如填写 10 表示返利 10%)。',
+ customUsers: {
+ title: '专属用户配置',
+ description: '为指定用户设置专属邀请码或专属返利比例。仅展示已设置过专属配置的用户。',
+ addButton: '添加专属用户',
+ searchPlaceholder: '搜索邮箱或用户名',
+ batchButton: '批量设置比例(已选 {count})',
+ empty: '暂无专属配置用户',
+ customBadge: '自定义',
+ useGlobal: '沿用全局',
+ resetTitle: '重置该用户的专属配置',
+ resetMessage: '确认将 {email} 的专属配置全部重置为默认?\n• 专属返利比例将清除(沿用全局)\n• 邀请码将重新生成为系统随机码(已分发的旧邀请链接将失效)',
+ totalLabel: '共 {total} 条',
+ col: {
+ email: '邮箱',
+ username: '用户名',
+ code: '邀请码',
+ rate: '专属比例',
+ actions: '操作',
+ },
+ },
+ modal: {
+ addTitle: '添加专属用户',
+ editTitle: '编辑专属配置',
+ userLabel: '用户',
+ userPlaceholder: '搜索邮箱或用户名',
+ changeUser: '更换用户',
+ codeLabel: '专属邀请码(可选)',
+ codePlaceholder: '例如 VIP2026',
+ codeHint: '4-32 位,仅支持大写字母、数字、下划线、连字符;留空表示不修改;输入将自动转大写。',
+ rateLabel: '专属返利比例(可选)',
+ ratePlaceholder: '例如 30',
+ rateHint: '0-100%;留空(编辑模式下)表示清除专属比例并沿用全局。',
+ errorBadRate: '请输入 0-100 之间的比例',
+ errorEmpty: '至少填写一项:专属邀请码或专属返利比例',
+ },
+ batchModal: {
+ title: '批量设置专属比例(已选 {count} 个用户)',
+ hint: '为所选用户统一设置专属返利比例。',
+ placeholder: '例如 30',
+ clearHint: '留空提交将清除所选用户的专属比例。',
+ },
+ },
},
emailTabDisabledTitle: '邮箱验证未启用',
emailTabDisabledHint: '请在「安全与认证」选项卡中启用邮箱验证后,再配置 SMTP 设置。',
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index e6cd9eff..876ab5c0 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -355,6 +355,7 @@ export const useAppStore = defineStore('app', () => {
channel_monitor_enabled: true,
channel_monitor_default_interval_seconds: 60,
available_channels_enabled: false,
+ affiliate_enabled: false,
}
}
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 50b4353e..2a15ad00 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -139,6 +139,8 @@ export interface UserAffiliateDetail {
aff_count: number
aff_quota: number
aff_history_quota: number
+ /** 当前用户作为邀请人时实际生效的返利比例(专属覆盖全局)。0-100。 */
+ effective_rebate_rate_percent: number
invitees: AffiliateInvitee[]
}
@@ -212,6 +214,7 @@ export interface PublicSettings {
channel_monitor_enabled: boolean
channel_monitor_default_interval_seconds: number
available_channels_enabled: boolean
+ affiliate_enabled: boolean
}
export interface AuthResponse {
diff --git a/frontend/src/utils/featureFlags.ts b/frontend/src/utils/featureFlags.ts
index 51b043cc..e0668694 100644
--- a/frontend/src/utils/featureFlags.ts
+++ b/frontend/src/utils/featureFlags.ts
@@ -109,6 +109,11 @@ export const FeatureFlags = {
mode: 'opt-out',
label: 'Payment',
}),
+ affiliate: defineFlag({
+ key: 'affiliate_enabled',
+ mode: 'opt-in',
+ label: 'Affiliate',
+ }),
} as const
export type RegisteredFeatureFlag = keyof typeof FeatureFlags
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 6da4b21a..87113e59 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2153,31 +2153,6 @@
{{ t("admin.settings.defaults.defaultBalanceHint") }}
-
-
-
-
- %
-
-
- {{ t("admin.settings.defaults.affiliateRebateRateHint") }}
-
-
+
+
+
+
+ {{ t('admin.settings.features.affiliate.title') }}
+
+
+ {{ t('admin.settings.features.affiliate.description') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.enabledHint') }}
+
+
+
+
+
+
+
+
+
+
+ %
+
+
+ {{ t('admin.settings.features.affiliate.rebateRateHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.customUsers.title') }}
+
+
+ {{ t('admin.settings.features.affiliate.customUsers.description') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.customUsers.totalLabel', { total: affiliateState.total }) }}
+
+
+
+ {{ affiliateState.page }} / {{ Math.max(1, Math.ceil(affiliateState.total / affiliateState.pageSize)) }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ affiliateModal.mode === 'add' ? t('admin.settings.features.affiliate.modal.addTitle') : t('admin.settings.features.affiliate.modal.editTitle') }}
+
+
+
+
+
+
+
+ {{ affiliateModal.selectedUser.email }}
+ ({{ affiliateModal.selectedUser.username }})
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.modal.codeHint') }}
+
+
+
+
+
+
+
+ %
+
+
+ {{ t('admin.settings.features.affiliate.modal.rateHint') }}
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.modal.errorEmpty') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.features.affiliate.batchModal.title', { count: affiliateState.selected.length }) }}
+
+
+ {{ t('admin.settings.features.affiliate.batchModal.hint') }}
+
+
+
+ %
+
+
+ {{ t('admin.settings.features.affiliate.batchModal.clearHint') }}
+
+
+
+
+
+
+
+
@@ -4793,12 +5118,21 @@
@confirm="handleDeleteProvider"
@cancel="showDeleteProviderDialog = false"
/>
+