From f68909a68ba31008f1d30f9f42069b6326f36525 Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Tue, 21 Apr 2026 08:54:18 +0800 Subject: [PATCH 01/33] fix: reconcile openai admin test rate-limit state --- .../internal/service/account_test_service.go | 36 +++++++++++++++++++ .../account_test_service_openai_test.go | 19 ++++++---- backend/internal/service/ratelimit_service.go | 6 +++- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index e5bc93ca..bb2fb8c0 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -538,6 +538,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusTooManyRequests { + s.reconcileOpenAI429State(ctx, account, resp.Header, body) + } // 401 Unauthorized: 标记账号为永久错误 if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) @@ -550,6 +553,39 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.processOpenAIStream(c, resp.Body) } +func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) { + if s == nil || s.accountRepo == nil || account == nil { + return + } + + var resetAt *time.Time + if calculated := calculateOpenAI429ResetTime(headers); calculated != nil { + resetAt = calculated + } else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil { + t := time.Unix(*unixTs, 0) + resetAt = &t + } + if resetAt == nil { + return + } + + if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { + return + } + + now := time.Now() + account.RateLimitedAt = &now + account.RateLimitResetAt = resetAt + + if account.Status == StatusError { + if err := s.accountRepo.ClearError(ctx, account.ID); err != nil { + return + } + account.Status = StatusActive + account.ErrorMessage = "" + } +} + // testGeminiAccountConnection tests a Gemini account's connection func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 82ff0a8b..213ef52c 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -61,9 +61,10 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) { type openAIAccountTestRepo struct { mockAccountRepoForGemini - updatedExtra map[string]any + updatedExtra map[string]any rateLimitedID int64 rateLimitedAt *time.Time + clearedErrorID int64 } func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { @@ -77,6 +78,11 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese return nil } +func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error { + r.clearedErrorID = id + return nil +} + func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { gin.SetMode(gin.TestMode) ctx, recorder := newTestContext() @@ -111,11 +117,11 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. require.Contains(t, recorder.Body.String(), "test_complete") } -func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) { +func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) { gin.SetMode(gin.TestMode) ctx, _ := newTestContext() - resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`) resp.Header.Set("x-codex-primary-used-percent", "100") resp.Header.Set("x-codex-primary-reset-after-seconds", "604800") resp.Header.Set("x-codex-primary-window-minutes", "10080") @@ -130,6 +136,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing ID: 88, Platform: PlatformOpenAI, Type: AccountTypeOAuth, + Status: StatusError, Concurrency: 1, Credentials: map[string]any{"access_token": "test-token"}, } @@ -138,7 +145,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) - require.Zero(t, repo.rateLimitedID) - require.Nil(t, repo.rateLimitedAt) - require.Nil(t, account.RateLimitResetAt) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Equal(t, account.ID, repo.clearedErrorID) } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 4730303f..9344de47 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -931,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间 // 返回 nil 表示无法从响应头中确定重置时间 -func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { +func calculateOpenAI429ResetTime(headers http.Header) *time.Time { snapshot := ParseCodexRateLimitHeaders(headers) if snapshot == nil { return nil @@ -977,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { + return calculateOpenAI429ResetTime(headers) +} + // anthropic429Result holds the parsed Anthropic 429 rate-limit information. type anthropic429Result struct { resetAt time.Time // The correct reset time to use for SetRateLimited From 5fc30ea964c47bac64b4c3dbde5e4c4180eb5078 Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Tue, 21 Apr 2026 09:03:25 +0800 Subject: [PATCH 02/33] test: cover openai admin test state transitions --- .../account_test_service_openai_test.go | 130 +++++++++++++++++- 1 file changed, 127 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 213ef52c..12d8128a 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -61,10 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) { type openAIAccountTestRepo struct { mockAccountRepoForGemini - updatedExtra map[string]any - rateLimitedID int64 - rateLimitedAt *time.Time + updatedExtra map[string]any + rateLimitedID int64 + rateLimitedAt *time.Time clearedErrorID int64 + setErrorID int64 + setErrorMsg string } func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { @@ -83,6 +85,12 @@ func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error { return nil } +func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error { + r.setErrorID = id + r.setErrorMsg = errorMsg + return nil +} + func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { gin.SetMode(gin.TestMode) ctx, recorder := newTestContext() @@ -148,4 +156,120 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testin require.Equal(t, account.ID, repo.rateLimitedID) require.NotNil(t, repo.rateLimitedAt) require.Equal(t, account.ID, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.Empty(t, account.ErrorMessage) + require.NotNil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 77, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions", + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Equal(t, account.ID, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.Empty(t, account.ErrorMessage) + require.NotNil(t, account.RateLimitResetAt) + require.Empty(t, repo.updatedExtra) +} + +func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 78, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.Equal(t, account.ID, repo.rateLimitedID) + require.NotNil(t, repo.rateLimitedAt) + require.Zero(t, repo.clearedErrorID) + require.Equal(t, StatusActive, account.Status) + require.NotNil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 79, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "stale 403", + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.Zero(t, repo.rateLimitedID) + require.Nil(t, repo.rateLimitedAt) + require.Zero(t, repo.clearedErrorID) + require.Equal(t, StatusError, account.Status) + require.Equal(t, "stale 403", account.ErrorMessage) + require.Nil(t, account.RateLimitResetAt) +} + +func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`) + + repo := &openAIAccountTestRepo{} + upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}} + svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream} + account := &Account{ + ID: 80, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{"access_token": "test-token"}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + require.Error(t, err) + require.Equal(t, account.ID, repo.setErrorID) + require.Contains(t, repo.setErrorMsg, "Authentication failed (401)") + require.Zero(t, repo.rateLimitedID) + require.Zero(t, repo.clearedErrorID) + require.Nil(t, account.RateLimitResetAt) } From d80469ea35d6dc085e9e4aee43834280c3a18b5d Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Thu, 23 Apr 2026 18:15:00 +0800 Subject: [PATCH 03/33] test: fix OpenAI account test helper calls after rebase --- .../internal/service/account_test_service_openai_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 12d8128a..c1e42b4f 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -180,7 +180,7 @@ func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleErro Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.Equal(t, account.ID, repo.rateLimitedID) require.NotNil(t, repo.rateLimitedAt) @@ -209,7 +209,7 @@ func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.Equal(t, account.ID, repo.rateLimitedID) require.NotNil(t, repo.rateLimitedAt) @@ -237,7 +237,7 @@ func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.Zero(t, repo.rateLimitedID) require.Nil(t, repo.rateLimitedAt) @@ -265,7 +265,7 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) { Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.Equal(t, account.ID, repo.setErrorID) require.Contains(t, repo.setErrorMsg, "Authentication failed (401)") From f3ea878ba21297f9237142d20c6a83f698cbabbb Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Thu, 23 Apr 2026 18:33:27 +0800 Subject: [PATCH 04/33] chore: trigger PR checks From c4d496da18db6885b7d22d449451b1dbf518b73d Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Fri, 24 Apr 2026 07:42:31 +0000 Subject: [PATCH 05/33] fix(openai): handle codex spark model limitations --- .../service/openai_codex_transform.go | 66 +++++++++++ .../service/openai_codex_transform_test.go | 111 ++++++++++++++++++ .../service/openai_gateway_service.go | 11 ++ .../internal/service/openai_model_mapping.go | 24 +++- .../service/openai_model_mapping_test.go | 47 +++++++- 5 files changed, 257 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 14abde9b..f9c0de72 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -48,6 +48,8 @@ type codexTransformResult struct { const ( codexImageGenerationBridgeMarker = "" codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n" + codexSparkImageUnsupportedMarker = "" + codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n" ) func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { @@ -165,6 +167,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if applyInstructions(reqBody, isCodexCLI) { result.Modified = true } + if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) { + result.Modified = true + } // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 if input, ok := reqBody["input"].([]any); ok { @@ -244,6 +249,10 @@ func normalizeCodexModel(model string) string { return "gpt-5.4" } +func isCodexSparkModel(model string) bool { + return normalizeCodexModel(model) == "gpt-5.3-codex-spark" +} + func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -265,6 +274,40 @@ func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { return false } +func hasOpenAIInputImage(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"]) +} + +func hasOpenAIInputImageValue(value any) bool { + switch v := value.(type) { + case []any: + for _, item := range v { + if hasOpenAIInputImageValue(item) { + return true + } + } + case map[string]any: + if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" { + return true + } + if _, ok := v["image_url"]; ok { + return true + } + return hasOpenAIInputImageValue(v["content"]) + } + return false +} + +func validateCodexSparkInput(reqBody map[string]any, model string) error { + if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) { + return nil + } + return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model)) +} + func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -309,6 +352,9 @@ func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool { if len(reqBody) == 0 { return false } + if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) { + return false + } tool := map[string]any{ "type": "image_generation", @@ -344,6 +390,9 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) { return false } + if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) { + return false + } existing, _ := reqBody["instructions"].(string) if strings.Contains(existing, codexImageGenerationBridgeMarker) { @@ -360,6 +409,23 @@ func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { return true } +func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + existing, _ := reqBody["instructions"].(string) + if strings.Contains(existing, codexSparkImageUnsupportedMarker) { + return false + } + existing = strings.TrimRight(existing, " \t\r\n") + if strings.TrimSpace(existing) == "" { + reqBody["instructions"] = codexSparkImageUnsupportedText + return true + } + reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText + return true +} + func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { if !hasOpenAIImageGenerationTool(reqBody) { return nil diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 4fd16fdb..3a965d5c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -261,6 +261,17 @@ func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) { require.Equal(t, "png", tool["output_format"]) } +func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": "draw a cat", + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.False(t, modified) + require.NotContains(t, reqBody, "tools") +} + func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.4", @@ -306,6 +317,7 @@ func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t * func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) { reqBody := map[string]any{ + "model": "gpt-5.4", "instructions": "existing instructions", "tools": []any{ map[string]any{"type": "image_generation", "output_format": "png"}, @@ -325,6 +337,20 @@ func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testin require.False(t, modified) } +func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "instructions": "existing instructions", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "png"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) + require.Equal(t, "existing instructions", reqBody["instructions"]) +} + func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) { reqBody := map[string]any{ "instructions": "existing instructions", @@ -338,6 +364,91 @@ func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *te require.Equal(t, "existing instructions", reqBody["instructions"]) } +func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "describe"}, + map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="}, + }, + }, + }, + } + + err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark") + require.Error(t, err) + require.Contains(t, err.Error(), "does not support image input") +} + +func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}}, + }, + }, + }, + } + + err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark") + require.Error(t, err) +} + +func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "hello"}, + }, + }, + }, + } + + require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")) +} + +func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "instructions": "existing instructions", + "input": "hello", + } + + result := applyCodexOAuthTransform(reqBody, true, false) + require.True(t, result.Modified) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Contains(t, instructions, "existing instructions") + require.Contains(t, instructions, codexSparkImageUnsupportedMarker) + require.Contains(t, instructions, "does not support image generation") + require.Contains(t, instructions, "switch to a non-Spark Codex model") + require.NotContains(t, instructions, codexImageGenerationBridgeMarker) +} + +func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "instructions": "existing instructions", + "input": "hello", + } + + applyCodexOAuthTransform(reqBody, true, false) + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotContains(t, instructions, codexSparkImageUnsupportedMarker) +} + func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) { reqBody := map[string]any{ "model": "gpt-image-2", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d99cd7da..2d05c3ea 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1995,6 +1995,17 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco account.Type, ) } + if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "input", + }, + }) + return nil, err + } // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 9bf3fba3..993c0b13 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,5 +1,7 @@ package service +import "strings" + // resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible // forwarding. Group-level default mapping only applies when the account itself // did not match any explicit model_mapping rule. @@ -12,8 +14,28 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } mappedModel, matched := account.ResolveMappedModel(requestedModel) - if !matched && defaultMappedModel != "" { + if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) { return defaultMappedModel } return mappedModel } + +func isExplicitCodexModel(model string) bool { + model = strings.TrimSpace(model) + if model == "" { + return false + } + if strings.Contains(model, "/") { + parts := strings.Split(model, "/") + model = parts[len(parts)-1] + } + model = strings.ToLower(strings.TrimSpace(model)) + if getNormalizedCodexModel(model) != "" { + return true + } + if strings.HasSuffix(model, "-openai-compact") { + base := strings.TrimSuffix(model, "-openai-compact") + return getNormalizedCodexModel(base) != "" + } + return false +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index f25863a8..21a2e9a0 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) { account: &Account{ Credentials: map[string]any{}, }, - requestedModel: "gpt-5.4", + requestedModel: "claude-opus-4-6", defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-4o-mini", }, + { + name: "preserves explicit gpt-5.4 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, { name: "preserves exact passthrough mapping instead of group default", account: &Account{ @@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-4o-mini", expectedModel: "gpt-5.4", }, + { + name: "preserves codex spark instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.3-codex-spark", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.3-codex-spark", + }, + { + name: "preserves gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.5", + }, + { + name: "preserves openai namespaced gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "openai/gpt-5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "openai/gpt-5.5", + }, + { + name: "preserves compact gpt-5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.5-openai-compact", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt-5.5-openai-compact", + }, } for _, tt := range tests { From 959af1c8f6fb120863242d2be5b62b5820b8af0e Mon Sep 17 00:00:00 2001 From: song Date: Fri, 24 Apr 2026 17:15:42 +0800 Subject: [PATCH 06/33] fix(openai): preserve codex tool call ids --- .../service/openai_codex_transform.go | 22 ++++-- .../service/openai_codex_transform_test.go | 72 +++++++++++++++++++ .../service/openai_tool_continuation.go | 4 +- .../service/openai_tool_continuation_test.go | 3 + 4 files changed, 94 insertions(+), 7 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 14abde9b..65f7f5b4 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -658,12 +658,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } } + if !isCodexToolCallItemType(typ) { + ensureCopy() + delete(newItem, "call_id") + } + if !preserveReferences { ensureCopy() delete(newItem, "id") - if !isCodexToolCallItemType(typ) { - delete(newItem, "call_id") - } } filtered = append(filtered, newItem) @@ -672,10 +674,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } func isCodexToolCallItemType(typ string) bool { - if typ == "" { + switch typ { + case "function_call", + "tool_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "function_call_output", + "mcp_tool_call_output", + "custom_tool_call_output", + "tool_search_output": + return true + default: return false } - return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output") } func normalizeCodexTools(reqBody map[string]any) bool { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 4fd16fdb..476f1ea9 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -92,6 +92,78 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly require.Equal(t, "fc1", second["call_id"]) } +func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"}, + }, + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "tool_search_output", first["type"]) + require.Equal(t, "fc1", first["call_id"]) +} + +func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"}, + map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"}, + }, + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "fccustom", first["call_id"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "fcmcp", second["call_id"]) +} + +func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.2", + "input": []any{ + map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"}, + map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"}, + }, + "tool_choice": "auto", + } + + applyCodexOAuthTransform(reqBody, false, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "ig_123", first["id"]) + _, hasCallID := first["call_id"] + require.False(t, hasCallID) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + _, hasCallID = second["call_id"] + require.False(t, hasCallID) +} + func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index dea3c172..c0f98de4 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct { } // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 -// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 +// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、 // 或显式声明 tools/tool_choice。 func NeedsToolContinuation(reqBody map[string]any) bool { if reqBody == nil { @@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool { continue } itemType, _ := itemMap["type"].(string) - if itemType == "function_call_output" || itemType == "item_reference" { + if isCodexToolCallItemType(itemType) || itemType == "item_reference" { return true } } diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go index fe737ad6..3f415d9d 100644 --- a/backend/internal/service/openai_tool_continuation_test.go +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) { {name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true}, {name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false}, {name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true}, + {name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true}, + {name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true}, + {name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true}, {name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true}, {name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true}, {name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false}, From e65574dea99d42da4681a52f10e7f9bdc5deed99 Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Fri, 24 Apr 2026 12:03:19 +0000 Subject: [PATCH 07/33] fix(openai): normalize codex responses payloads --- .../service/openai_codex_transform.go | 214 ++++++++++++++++++ .../service/openai_codex_transform_test.go | 125 ++++++++++ 2 files changed, 339 insertions(+) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 20d303b3..4903c420 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "fmt" "strings" ) @@ -153,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if normalizeCodexTools(reqBody) { result.Modified = true } + if normalizeCodexToolChoice(reqBody) { + result.Modified = true + } if v, ok := reqBody["prompt_cache_key"].(string); ok { result.PromptCacheKey = strings.TrimSpace(v) @@ -173,6 +177,14 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 if input, ok := reqBody["input"].([]any); ok { + if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified { + input = normalizedInput + result.Modified = true + } + if normalizedInput, modified := normalizeCodexMessageContentText(input); modified { + input = normalizedInput + result.Modified = true + } input = filterCodexInput(input, needsToolContinuation) reqBody["input"] = input result.Modified = true @@ -197,6 +209,183 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact return result } +func normalizeCodexToolChoice(reqBody map[string]any) bool { + choice, ok := reqBody["tool_choice"] + if !ok || choice == nil { + return false + } + choiceMap, ok := choice.(map[string]any) + if !ok { + return false + } + choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"])) + if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) { + return false + } + reqBody["tool_choice"] = "auto" + return true +} + +func codexToolsContainType(rawTools any, toolType string) bool { + tools, ok := rawTools.([]any) + if !ok || strings.TrimSpace(toolType) == "" { + return false + } + for _, rawTool := range tools { + tool, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType { + return true + } + } + return false +} + +func normalizeCodexToolRoleMessages(input []any) ([]any, bool) { + if len(input) == 0 { + return input, false + } + + modified := false + normalized := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + normalized = append(normalized, item) + continue + } + role, _ := m["role"].(string) + if strings.TrimSpace(role) != "tool" { + normalized = append(normalized, item) + continue + } + + callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"]) + callID = strings.TrimSpace(callID) + if callID == "" { + // Responses does not accept role:"tool". If no call id is available, + // preserve the text as a user message instead of sending invalid input. + fallback := make(map[string]any, len(m)) + for key, value := range m { + fallback[key] = value + } + fallback["role"] = "user" + delete(fallback, "tool_call_id") + normalized = append(normalized, fallback) + modified = true + continue + } + + output := extractTextFromContent(m["content"]) + if output == "" { + if value, ok := m["output"].(string); ok { + output = value + } + } + if output == "" && m["content"] != nil { + if b, err := json.Marshal(m["content"]); err == nil { + output = string(b) + } + } + + normalized = append(normalized, map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": output, + }) + modified = true + } + if !modified { + return input, false + } + return normalized, true +} + +func normalizeCodexMessageContentText(input []any) ([]any, bool) { + if len(input) == 0 { + return input, false + } + + modified := false + normalized := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" { + normalized = append(normalized, item) + continue + } + parts, ok := m["content"].([]any) + if !ok { + normalized = append(normalized, item) + continue + } + + var newItem map[string]any + var newParts []any + ensureItemCopy := func() { + if newItem != nil { + return + } + newItem = make(map[string]any, len(m)) + for key, value := range m { + newItem[key] = value + } + newParts = make([]any, len(parts)) + copy(newParts, parts) + } + + for i, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + text, hasText := part["text"] + if !hasText { + continue + } + if _, ok := text.(string); ok { + continue + } + + ensureItemCopy() + newPart := make(map[string]any, len(part)) + for key, value := range part { + newPart[key] = value + } + newPart["text"] = stringifyCodexContentText(text) + newParts[i] = newPart + modified = true + } + + if newItem != nil { + newItem["content"] = newParts + normalized = append(normalized, newItem) + continue + } + normalized = append(normalized, item) + } + if !modified { + return input, false + } + return normalized, true +} + +func stringifyCodexContentText(value any) string { + switch v := value.(type) { + case string: + return v + case nil: + return "" + default: + if b, err := json.Marshal(v); err == nil { + return string(b) + } + return fmt.Sprint(v) + } +} + func normalizeCodexModel(model string) string { model = strings.TrimSpace(model) if model == "" { @@ -729,6 +918,22 @@ func filterCodexInput(input []any, preserveReferences bool) []any { delete(newItem, "call_id") } + if codexInputItemRequiresName(typ) { + if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" { + name := firstNonEmptyString(m["tool_name"]) + if name == "" { + if function, ok := m["function"].(map[string]any); ok { + name = firstNonEmptyString(function["name"]) + } + } + if name == "" { + name = "tool" + } + ensureCopy() + newItem["name"] = name + } + } + if !preserveReferences { ensureCopy() delete(newItem, "id") @@ -756,6 +961,15 @@ func isCodexToolCallItemType(typ string) bool { } } +func codexInputItemRequiresName(typ string) bool { + switch strings.TrimSpace(typ) { + case "function_call", "custom_tool_call", "mcp_tool_call": + return true + default: + return false + } +} + func normalizeCodexTools(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index f655e61c..ca9b4cea 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -164,6 +164,131 @@ func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testi require.False(t, hasCallID) } +func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "message", + "role": "tool", + "tool_call_id": "call_1", + "content": "ok", + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call_output", item["type"]) + require.Equal(t, "fc1", item["call_id"]) + require.Equal(t, "ok", item["output"]) + _, hasRole := item["role"] + require.False(t, hasRole) +} + +func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "message", + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": []any{"a", "b"}}, + }, + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + item, ok := input[0].(map[string]any) + require.True(t, ok) + content, ok := item["content"].([]any) + require.True(t, ok) + part, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, `["a","b"]`, part["text"]) +} + +func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{"type": "custom"}, + } + + applyCodexOAuthTransform(reqBody, true, false) + + require.Equal(t, "auto", reqBody["tool_choice"]) +} + +func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "custom", "name": "shell"}, + }, + "tool_choice": map[string]any{"type": "custom"}, + } + + applyCodexOAuthTransform(reqBody, true, false) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "custom", choice["type"]) +} + +func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{"type": "message", "role": "user", "content": "run tool"}, + map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + item, ok := input[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call", item["type"]) + require.Equal(t, "tool", item["name"]) + require.Equal(t, "fc1", item["call_id"]) +} + +func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "shell", item["name"]) + require.Equal(t, "fc1", item["call_id"]) +} + func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 From 27ee141c1edef7682fb19e3cf14b9433a592998e Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Fri, 24 Apr 2026 13:24:21 +0000 Subject: [PATCH 08/33] fix(openai): preserve mcp tool call ids --- .../service/openai_codex_transform.go | 1 + .../service/openai_codex_transform_test.go | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 4903c420..e765d7e9 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -951,6 +951,7 @@ func isCodexToolCallItemType(typ string) bool { "local_shell_call", "tool_search_call", "custom_tool_call", + "mcp_tool_call", "function_call_output", "mcp_tool_call_output", "custom_tool_call_output", diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index ca9b4cea..75f5c55c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -289,6 +289,38 @@ func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) { require.Equal(t, "fc1", item["call_id"]) } +func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": []any{ + map[string]any{ + "type": "mcp_tool_call", + "call_id": "call_abc", + "name": "remote_tool", + "arguments": "{}", + }, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + input, ok := reqBody["input"].([]any) + require.True(t, ok) + require.Len(t, input, 1) + item, ok := input[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "mcp_tool_call", item["type"]) + require.Equal(t, "remote_tool", item["name"]) + require.Equal(t, "fcabc", item["call_id"]) +} + +func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) { + for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} { + require.True(t, codexInputItemRequiresName(typ), typ) + require.True(t, isCodexToolCallItemType(typ), typ) + } +} + func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 From f03de00cb9a8ead97cdf4507650a746eeb66a198 Mon Sep 17 00:00:00 2001 From: VpSanta33 <1398474964@qq.com> Date: Fri, 24 Apr 2026 21:41:26 +0800 Subject: [PATCH 09/33] feat: add affiliate invite rebate flow and admin rebate-rate setting --- backend/cmd/server/wire_gen.go | 28 +- .../internal/handler/admin/setting_handler.go | 17 + backend/internal/handler/auth_handler.go | 11 +- backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/user_handler.go | 74 ++- backend/internal/handler/wire.go | 14 +- backend/internal/repository/affiliate_repo.go | 420 ++++++++++++++++++ .../affiliate_repo_integration_test.go | 114 +++++ backend/internal/repository/wire.go | 1 + backend/internal/server/api_contract_test.go | 2 + backend/internal/server/routes/user.go | 2 + backend/internal/service/affiliate_service.go | 288 ++++++++++++ .../service/affiliate_service_test.go | 59 +++ backend/internal/service/auth_service.go | 29 +- backend/internal/service/domain_constants.go | 8 + .../internal/service/payment_fulfillment.go | 136 ++++++ backend/internal/service/payment_service.go | 30 +- backend/internal/service/setting_service.go | 22 + backend/internal/service/settings_view.go | 1 + backend/internal/service/wire.go | 52 ++- .../migrations/130_add_user_affiliates.sql | 20 + .../131_affiliate_rebate_hardening.sql | 58 +++ frontend/src/api/admin/settings.ts | 2 + frontend/src/api/user.ts | 23 +- frontend/src/components/layout/AppSidebar.vue | 1 + frontend/src/i18n/locales/en.ts | 45 ++ frontend/src/i18n/locales/zh.ts | 44 ++ frontend/src/router/index.ts | 12 + frontend/src/types/index.ts | 23 + frontend/src/views/admin/SettingsView.vue | 30 ++ frontend/src/views/auth/EmailVerifyView.vue | 5 +- frontend/src/views/auth/RegisterView.vue | 13 +- frontend/src/views/user/AffiliateView.vue | 201 +++++++++ 33 files changed, 1744 insertions(+), 42 deletions(-) create mode 100644 backend/internal/repository/affiliate_repo.go create mode 100644 backend/internal/repository/affiliate_repo_integration_test.go create mode 100644 backend/internal/service/affiliate_service.go create mode 100644 backend/internal/service/affiliate_service_test.go create mode 100644 backend/migrations/130_add_user_affiliates.sql create mode 100644 backend/migrations/131_affiliate_rebate_hardening.sql create mode 100644 frontend/src/views/user/AffiliateView.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 93270e7e..f8e0dcf4 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -69,7 +69,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + affiliateRepository := repository.NewAffiliateRepository(client, db) + affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService) + authService := service.ProvideAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) @@ -80,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache) + userHandler := handler.ProvideUserHandler(userService, authService, emailService, emailCache, affiliateService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -91,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { announcementReadRepository := repository.NewAnnouncementReadRepository(client) announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository) announcementHandler := handler.NewAnnouncementHandler(announcementService) + channelMonitorRepository := repository.NewChannelMonitorRepository(client, db) + channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor) + channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db) dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig) @@ -192,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) @@ -221,20 +226,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) - sqlDB, err := repository.ProvideSQLDB(client) - if err != nil { - return nil, err - } - channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB) - channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB) + channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) + channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db) channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) - channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor) - channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) - channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) - channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) - availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) @@ -245,9 +241,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpHandler := handler.NewTotpHandler(totpService) handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService) paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) + availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -263,6 +260,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) + channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner) application := &Application{ Server: httpServer, diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 4277f0f1..2d4dcb5b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + AffiliateRebateRate: settings.AffiliateRebateRate, DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, @@ -338,6 +339,7 @@ type UpdateSettingsRequest struct { // 默认配置 DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` @@ -468,6 +470,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.DefaultBalance < 0 { req.DefaultBalance = 0 } + affiliateRebateRate := previousSettings.AffiliateRebateRate + if req.AffiliateRebateRate != nil { + affiliateRebateRate = *req.AffiliateRebateRate + } + if affiliateRebateRate < service.AffiliateRebateRateMin { + affiliateRebateRate = service.AffiliateRebateRateMin + } + if affiliateRebateRate > service.AffiliateRebateRateMax { + affiliateRebateRate = service.AffiliateRebateRateMax + } // 通用表格配置:兼容旧客户端未传字段时保留当前值。 if req.TableDefaultPageSize <= 0 { req.TableDefaultPageSize = previousSettings.TableDefaultPageSize @@ -1119,6 +1131,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + AffiliateRebateRate: affiliateRebateRate, DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, @@ -1433,6 +1446,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + AffiliateRebateRate: updatedSettings.AffiliateRebateRate, DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -1738,6 +1752,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } + if before.AffiliateRebateRate != after.AffiliateRebateRate { + changed = append(changed, "affiliate_rebate_rate") + } if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { changed = append(changed, "default_subscriptions") } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index dc68a466..1f9a66ff 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -48,6 +48,7 @@ type RegisterRequest struct { TurnstileToken string `json:"turnstile_token"` PromoCode string `json:"promo_code"` // 注册优惠码 InvitationCode string `json:"invitation_code"` // 邀请码 + AffCode string `json:"aff_code"` // 邀请返利码 } // SendVerifyCodeRequest 发送验证码请求 @@ -164,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) + _, user, err := h.authService.RegisterWithVerification( + c.Request.Context(), + req.Email, + req.Password, + req.VerifyCode, + req.PromoCode, + req.InvitationCode, + req.AffCode, + ) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 2affbc46..86074df7 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -108,6 +108,7 @@ type SystemSettings struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index f74c2b72..c386792c 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -14,10 +15,11 @@ import ( // UserHandler handles user-related requests type UserHandler struct { - userService *service.UserService - authService *service.AuthService - emailService *service.EmailService - emailCache service.EmailCache + userService *service.UserService + authService *service.AuthService + emailService *service.EmailService + emailCache service.EmailCache + affiliateService *service.AffiliateService } // NewUserHandler creates a new UserHandler @@ -35,6 +37,13 @@ func NewUserHandler( } } +func (h *UserHandler) SetAffiliateService(affiliateService *service.AffiliateService) { + if h == nil { + return + } + h.affiliateService = affiliateService +} + // ChangePasswordRequest represents the change password request payload type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` @@ -159,6 +168,63 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { response.Success(c, profileResp) } +func (h *UserHandler) affiliateServiceOrErr() (*service.AffiliateService, error) { + if h == nil || h.affiliateService == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return h.affiliateService, nil +} + +// GetAffiliate returns the current user's affiliate details. +// GET /api/v1/user/aff +func (h *UserHandler) GetAffiliate(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + affiliateSvc, err := h.affiliateServiceOrErr() + if err != nil { + response.ErrorFrom(c, err) + return + } + + detail, err := affiliateSvc.GetAffiliateDetail(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, detail) +} + +// TransferAffiliateQuota transfers all available affiliate quota into current balance. +// POST /api/v1/user/aff/transfer +func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + affiliateSvc, err := h.affiliateServiceOrErr() + if err != nil { + response.ErrorFrom(c, err) + return + } + + transferred, balance, err := affiliateSvc.TransferAffiliateQuota(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "transferred_quota": transferred, + "balance": balance, + }) +} + type StartIdentityBindingRequest struct { Provider string `json:"provider" binding:"required"` RedirectTo string `json:"redirect_to"` diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 6d175488..d4b34fd2 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -80,6 +80,18 @@ func ProvideSettingHandler(settingService *service.SettingService, buildInfo Bui return NewSettingHandler(settingService, buildInfo.Version) } +func ProvideUserHandler( + userService *service.UserService, + authService *service.AuthService, + emailService *service.EmailService, + emailCache service.EmailCache, + affiliateService *service.AffiliateService, +) *UserHandler { + handler := NewUserHandler(userService, authService, emailService, emailCache) + handler.SetAffiliateService(affiliateService) + return handler +} + // ProvideHandlers creates the Handlers struct func ProvideHandlers( authHandler *AuthHandler, @@ -125,7 +137,7 @@ func ProvideHandlers( var ProviderSet = wire.NewSet( // Top-level handlers NewAuthHandler, - NewUserHandler, + ProvideUserHandler, NewAPIKeyHandler, NewUsageHandler, NewRedeemHandler, diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go new file mode 100644 index 00000000..342ddf4f --- /dev/null +++ b/backend/internal/repository/affiliate_repo.go @@ -0,0 +1,420 @@ +package repository + +import ( + "context" + "crypto/rand" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +const ( + affiliateCodeLength = 12 + affiliateCodeMaxAttempts = 12 +) + +var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") + +type affiliateQueryExecer interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +type affiliateRepository struct { + client *dbent.Client +} + +func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository { + return &affiliateRepository{client: client} +} + +func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) { + if userID <= 0 { + return nil, service.ErrUserNotFound + } + client := clientFromContext(ctx, r.client) + return ensureUserAffiliateWithClient(ctx, client, userID) +} + +func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) { + client := clientFromContext(ctx, r.client) + return queryAffiliateByCode(ctx, client, code) +} + +func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) { + var bound bool + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil { + return err + } + + res, err := txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL", + inviterID, userID, + ) + if err != nil { + return fmt.Errorf("bind inviter: %w", err) + } + affected, _ := res.RowsAffected() + if affected == 0 { + bound = false + return nil + } + + if _, err = txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1", + inviterID, + ); err != nil { + return fmt.Errorf("increment inviter aff_count: %w", err) + } + bound = true + return nil + }) + if err != nil { + return false, err + } + return bound, nil +} + +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) { + if amount <= 0 { + return false, nil + } + + var applied bool + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + res, err := txClient.ExecContext(txCtx, + "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2", + amount, inviterID, + ) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + applied = false + return nil + } + + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } + + applied = true + return nil + }) + if err != nil { + return false, err + } + return applied, nil +} + +func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) { + var transferred float64 + var newBalance float64 + + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil { + return err + } + + rows, err := txClient.QueryContext(txCtx, ` +WITH claimed AS ( + SELECT aff_quota::double precision AS amount + FROM user_affiliates + WHERE user_id = $1 + AND aff_quota > 0 + FOR UPDATE +), +cleared AS ( + UPDATE user_affiliates ua + SET aff_quota = 0, + updated_at = NOW() + FROM claimed c + WHERE ua.user_id = $1 + RETURNING c.amount +) +SELECT amount +FROM cleared`, userID) + if err != nil { + return fmt.Errorf("claim affiliate quota: %w", err) + } + + if !rows.Next() { + _ = rows.Close() + if err := rows.Err(); err != nil { + return err + } + return service.ErrAffiliateQuotaEmpty + } + if err := rows.Scan(&transferred); err != nil { + _ = rows.Close() + return err + } + if err := rows.Close(); err != nil { + return err + } + if transferred <= 0 { + return service.ErrAffiliateQuotaEmpty + } + + affected, err := txClient.User.Update(). + Where(user.IDEQ(userID)). + AddBalance(transferred). + AddTotalRecharged(transferred). + Save(txCtx) + if err != nil { + return fmt.Errorf("credit user balance by affiliate quota: %w", err) + } + if affected == 0 { + return service.ErrUserNotFound + } + + newBalance, err = queryUserBalance(txCtx, txClient, userID) + if err != nil { + return err + } + + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) +VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil { + return fmt.Errorf("insert affiliate transfer ledger: %w", err) + } + + return nil + }) + if err != nil { + return 0, 0, err + } + + return transferred, newBalance, nil +} + +func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) { + if limit <= 0 { + limit = 100 + } + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.created_at +FROM user_affiliates ua +LEFT JOIN users u ON u.id = ua.user_id +WHERE ua.inviter_id = $1 +ORDER BY ua.created_at DESC +LIMIT $2`, inviterID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + invitees := make([]service.AffiliateInvitee, 0) + for rows.Next() { + var item service.AffiliateInvitee + var createdAt time.Time + if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil { + return nil, err + } + item.CreatedAt = &createdAt + invitees = append(invitees, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return invitees, nil +} + +func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx, tx.Client()) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return fmt.Errorf("begin affiliate transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx, tx.Client()); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit affiliate transaction: %w", err) + } + return nil +} + +func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) { + summary, err := queryAffiliateByUserID(ctx, client, userID) + if err == nil { + return summary, nil + } + if !errors.Is(err, service.ErrAffiliateProfileNotFound) { + return nil, err + } + + for i := 0; i < affiliateCodeMaxAttempts; i++ { + code, codeErr := generateAffiliateCode() + if codeErr != nil { + return nil, codeErr + } + _, insertErr := client.ExecContext(ctx, ` +INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at) +VALUES ($1, $2, NOW(), NOW()) +ON CONFLICT (user_id) DO NOTHING`, userID, code) + if insertErr == nil { + break + } + if isAffiliateUniqueViolation(insertErr) { + continue + } + return nil, insertErr + } + + return queryAffiliateByUserID(ctx, client, userID) +} + +func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) { + rows, err := client.QueryContext(ctx, ` +SELECT user_id, + aff_code, + inviter_id, + aff_count, + aff_quota::double precision, + aff_history_quota::double precision, + created_at, + updated_at +FROM user_affiliates +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrAffiliateProfileNotFound + } + + var out service.AffiliateSummary + var inviterID sql.NullInt64 + if err := rows.Scan( + &out.UserID, + &out.AffCode, + &inviterID, + &out.AffCount, + &out.AffQuota, + &out.AffHistoryQuota, + &out.CreatedAt, + &out.UpdatedAt, + ); err != nil { + return nil, err + } + if inviterID.Valid { + out.InviterID = &inviterID.Int64 + } + return &out, nil +} + +func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) { + rows, err := client.QueryContext(ctx, ` +SELECT user_id, + aff_code, + inviter_id, + aff_count, + aff_quota::double precision, + aff_history_quota::double precision, + created_at, + updated_at +FROM user_affiliates +WHERE aff_code = $1 +LIMIT 1`, strings.ToUpper(strings.TrimSpace(code))) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrAffiliateProfileNotFound + } + + var out service.AffiliateSummary + var inviterID sql.NullInt64 + if err := rows.Scan( + &out.UserID, + &out.AffCode, + &inviterID, + &out.AffCount, + &out.AffQuota, + &out.AffHistoryQuota, + &out.CreatedAt, + &out.UpdatedAt, + ); err != nil { + return nil, err + } + if inviterID.Valid { + out.InviterID = &inviterID.Int64 + } + return &out, nil +} + +func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) { + rows, err := client.QueryContext(ctx, + "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1", + userID, + ) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return 0, err + } + return 0, service.ErrUserNotFound + } + var balance float64 + if err := rows.Scan(&balance); err != nil { + return 0, err + } + return balance, nil +} + +func generateAffiliateCode() (string, error) { + buf := make([]byte, affiliateCodeLength) + if _, err := rand.Read(buf); err != nil { + return "", fmt.Errorf("generate affiliate code: %w", err) + } + for i := range buf { + buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)] + } + return string(buf), nil +} + +func isAffiliateUniqueViolation(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) { + return string(pqErr.Code) == "23505" + } + return false +} diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go new file mode 100644 index 00000000..3ab5c0fb --- /dev/null +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -0,0 +1,114 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 { + t.Helper() + rows, err := client.QueryContext(ctx, query, args...) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next(), "expected one row") + var value float64 + require.NoError(t, rows.Scan(&value)) + require.NoError(t, rows.Err()) + return value +} + +func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int { + t.Helper() + rows, err := client.QueryContext(ctx, query, args...) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next(), "expected one row") + var value int + require.NoError(t, rows.Scan(&value)) + require.NoError(t, rows.Err()) + return value +} + +func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 5.5, + Concurrency: 5, + }) + + affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000) + _, err := client.ExecContext(txCtx, ` +INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at) +VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34) + require.NoError(t, err) + + transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID) + require.NoError(t, err) + require.InDelta(t, 12.34, transferred, 1e-9) + require.InDelta(t, 17.84, balance, 1e-9) + + affQuota := querySingleFloat(t, txCtx, client, + "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID) + require.InDelta(t, 0.0, affQuota, 1e-9) + + persistedBalance := querySingleFloat(t, txCtx, client, + "SELECT balance::double precision FROM users WHERE id = $1", u.ID) + require.InDelta(t, 17.84, persistedBalance, 1e-9) + + ledgerCount := querySingleInt(t, txCtx, client, + "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID) + require.Equal(t, 1, ledgerCount) +} + +func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewAffiliateRepository(client, integrationDB) + + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 3.21, + Concurrency: 5, + }) + + affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000) + _, err := client.ExecContext(txCtx, ` +INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at) +VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode) + require.NoError(t, err) + + transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID) + require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty) + require.InDelta(t, 0.0, transferred, 1e-9) + require.InDelta(t, 0.0, balance, 1e-9) + + persistedBalance := querySingleFloat(t, txCtx, client, + "SELECT balance::double precision FROM users WHERE id = $1", u.ID) + require.InDelta(t, 3.21, persistedBalance, 1e-9) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 6d24d312..f07bbb33 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet( NewChannelRepository, NewChannelMonitorRepository, NewChannelMonitorRequestTemplateRepository, + NewAffiliateRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e89ef3d9..35a6524a 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -715,6 +715,7 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "affiliate_rebate_rate": 20, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -895,6 +896,7 @@ func TestAPIContracts(t *testing.T) { "custom_endpoints": [], "default_concurrency": 0, "default_balance": 0, + "affiliate_rebate_rate": 20, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index babab125..9976954c 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,8 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.GET("/aff", h.User.GetAffiliate) + user.POST("/aff/transfer", h.User.TransferAffiliateQuota) user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode) user.POST("/account-bindings/email", h.User.BindEmailIdentity) user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go new file mode 100644 index 00000000..6fa5b423 --- /dev/null +++ b/backend/internal/service/affiliate_service.go @@ -0,0 +1,288 @@ +package service + +import ( + "context" + "errors" + "math" + "strconv" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +var ( + ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found") + ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code") + ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound") + ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer") +) + +const ( + affiliateInviteesLimit = 100 +) + +type AffiliateSummary struct { + UserID int64 `json:"user_id"` + AffCode string `json:"aff_code"` + InviterID *int64 `json:"inviter_id,omitempty"` + AffCount int `json:"aff_count"` + AffQuota float64 `json:"aff_quota"` + AffHistoryQuota float64 `json:"aff_history_quota"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type AffiliateInvitee struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + CreatedAt *time.Time `json:"created_at,omitempty"` +} + +type AffiliateDetail struct { + UserID int64 `json:"user_id"` + AffCode string `json:"aff_code"` + InviterID *int64 `json:"inviter_id,omitempty"` + AffCount int `json:"aff_count"` + AffQuota float64 `json:"aff_quota"` + AffHistoryQuota float64 `json:"aff_history_quota"` + Invitees []AffiliateInvitee `json:"invitees"` +} + +type AffiliateRepository interface { + EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) + GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) + BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) + AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) + TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) + ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error) +} + +type AffiliateService struct { + repo AffiliateRepository + settingRepo SettingRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + billingCacheService *BillingCacheService +} + +func NewAffiliateService(repo AffiliateRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService { + return &AffiliateService{ + repo: repo, + settingRepo: settingRepo, + authCacheInvalidator: authCacheInvalidator, + billingCacheService: billingCacheService, + } +} + +func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER", "invalid user") + } + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.EnsureUserAffiliate(ctx, userID) +} + +func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) { + summary, err := s.EnsureUserAffiliate(ctx, userID) + if err != nil { + return nil, err + } + invitees, err := s.listInvitees(ctx, userID) + if err != nil { + return nil, err + } + return &AffiliateDetail{ + UserID: summary.UserID, + AffCode: summary.AffCode, + InviterID: summary.InviterID, + AffCount: summary.AffCount, + AffQuota: summary.AffQuota, + AffHistoryQuota: summary.AffHistoryQuota, + Invitees: invitees, + }, nil +} + +func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error { + code := strings.ToUpper(strings.TrimSpace(rawCode)) + if code == "" { + return nil + } + if s == nil || s.repo == nil { + return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + + selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID) + if err != nil { + return err + } + if selfSummary.InviterID != nil { + return nil + } + + inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code) + if err != nil { + if errors.Is(err, ErrAffiliateProfileNotFound) { + return ErrAffiliateCodeInvalid + } + return err + } + if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID { + return ErrAffiliateCodeInvalid + } + + bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID) + if err != nil { + return err + } + if !bound { + return ErrAffiliateAlreadyBound + } + return nil +} + +func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) { + if s == nil || s.repo == nil { + return 0, nil + } + if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) { + return 0, nil + } + + inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID) + if err != nil { + return 0, err + } + if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 { + return 0, nil + } + + rebateRatePercent := s.loadAffiliateRebateRatePercent(ctx) + rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8) + if rebate <= 0 { + return 0, nil + } + + if _, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID); err != nil { + return 0, err + } + + applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate) + if err != nil { + return 0, err + } + if !applied { + return 0, nil + } + return rebate, nil +} + +func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) { + if s == nil || s.repo == nil { + return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + + transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID) + if err != nil { + return 0, 0, err + } + if transferred > 0 { + s.invalidateAffiliateCaches(ctx, userID) + } + return transferred, balance, nil +} + +func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) { + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit) + if err != nil { + return nil, err + } + for i := range invitees { + invitees[i].Email = maskEmail(invitees[i].Email) + } + return invitees, nil +} + +func (s *AffiliateService) loadAffiliateRebateRatePercent(ctx context.Context) float64 { + if s == nil || s.settingRepo == nil { + return AffiliateRebateRateDefault + } + + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate) + if err != nil { + return AffiliateRebateRateDefault + } + + rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64) + if err != nil { + return AffiliateRebateRateDefault + } + if math.IsNaN(rate) || math.IsInf(rate, 0) { + return AffiliateRebateRateDefault + } + if rate < AffiliateRebateRateMin { + return AffiliateRebateRateMin + } + if rate > AffiliateRebateRateMax { + return AffiliateRebateRateMax + } + return rate +} + +func roundTo(v float64, scale int) float64 { + factor := math.Pow10(scale) + return math.Round(v*factor) / factor +} + +func maskEmail(email string) string { + email = strings.TrimSpace(email) + if email == "" { + return "" + } + at := strings.Index(email, "@") + if at <= 0 || at >= len(email)-1 { + return "***" + } + + local := email[:at] + domain := email[at+1:] + dot := strings.LastIndex(domain, ".") + + maskedLocal := maskSegment(local) + if dot <= 0 || dot >= len(domain)-1 { + return maskedLocal + "@" + maskSegment(domain) + } + + domainName := domain[:dot] + tld := domain[dot:] + return maskedLocal + "@" + maskSegment(domainName) + tld +} + +func maskSegment(s string) string { + r := []rune(s) + if len(r) == 0 { + return "***" + } + if len(r) == 1 { + return string(r[0]) + "***" + } + return string(r[0]) + "***" +} + +func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) { + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + }() + } +} diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go new file mode 100644 index 00000000..6adf879d --- /dev/null +++ b/backend/internal/service/affiliate_service_test.go @@ -0,0 +1,59 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type affiliateSettingRepoStub struct { + value string + err error +} + +func (s *affiliateSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, s.err } +func (s *affiliateSettingRepoStub) GetValue(context.Context, string) (string, error) { + if s.err != nil { + return "", s.err + } + return s.value, nil +} +func (s *affiliateSettingRepoStub) Set(context.Context, string, string) error { return s.err } +func (s *affiliateSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + if s.err != nil { + return nil, s.err + } + return map[string]string{}, nil +} +func (s *affiliateSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return s.err +} +func (s *affiliateSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + if s.err != nil { + return nil, s.err + } + return map[string]string{}, nil +} +func (s *affiliateSettingRepoStub) Delete(context.Context, string) error { return s.err } + +func TestAffiliateRebateRatePercentSemantics(t *testing.T) { + t.Parallel() + + svc := &AffiliateService{settingRepo: &affiliateSettingRepoStub{value: "1"}} + rate := svc.loadAffiliateRebateRatePercent(context.Background()) + require.Equal(t, 1.0, rate) + + svc.settingRepo = &affiliateSettingRepoStub{value: "0.2"} + rate = svc.loadAffiliateRebateRatePercent(context.Background()) + require.Equal(t, 0.2, rate) +} + +func TestMaskEmail(t *testing.T) { + t.Parallel() + require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com")) + require.Equal(t, "x***@d***", maskEmail("x@domain")) + require.Equal(t, "", maskEmail("")) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index e45d8d66..fe0c32f5 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -72,6 +72,7 @@ type AuthService struct { turnstileService *TurnstileService emailQueueService *EmailQueueService promoService *PromoService + affiliateService *AffiliateService defaultSubAssigner DefaultSubscriptionAssigner } @@ -121,13 +122,26 @@ func (s *AuthService) EntClient() *dbent.Client { return s.entClient } +func (s *AuthService) SetAffiliateService(affiliateService *AffiliateService) { + if s == nil { + return + } + s.affiliateService = affiliateService +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") } -// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户 -func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) { +// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。 +// affiliateCode 使用可选参数以兼容旧调用方。 +func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string, affiliateCode ...string) (string, *User, error) { + affiliateCodeRaw := "" + if len(affiliateCode) > 0 { + affiliateCodeRaw = affiliateCode[0] + } + // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled @@ -223,6 +237,17 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } s.postAuthUserBootstrap(ctx, user, "email", true) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + if s.affiliateService != nil { + if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err) + } + if code := strings.TrimSpace(affiliateCodeRaw); code != "" { + if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil { + // 邀请返利码绑定失败不影响注册,只记录日志 + logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err) + } + } + } // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cf47b76f..23afeb87 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -18,6 +18,13 @@ const ( RoleUser = domain.RoleUser ) +// Affiliate rebate settings +const ( + AffiliateRebateRateDefault = 20.0 + AffiliateRebateRateMin = 0.0 + AffiliateRebateRateMax = 100.0 +) + // Platform constants const ( PlatformAnthropic = domain.PlatformAnthropic @@ -87,6 +94,7 @@ const ( SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接 SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 + SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 243edff3..c6167447 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -268,6 +269,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e switch action { case redeemActionSkipCompleted: + s.applyAffiliateRebateForOrder(ctx, o) // Code already created and redeemed — just mark completed return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") case redeemActionCreate: @@ -281,6 +283,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } + s.applyAffiliateRebateForOrder(ctx, o) return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") } @@ -358,6 +361,139 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action return c > 0 } +func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) { + if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 { + return + } + if s.affiliateService == nil { + return + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("begin affiliate rebate tx: %v", err), + }) + return + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + if !claimed { + return + } + + rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) + if err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + + if rebateAmount <= 0 { + if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{ + "baseAmount": o.Amount, + "reason": "no inviter bound or rebate amount <= 0", + }); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + if err := tx.Commit(); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), + }) + } + return + } + + if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{ + "baseAmount": o.Amount, + "rebateAmount": rebateAmount, + }); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": err.Error(), + }) + return + } + + if err := tx.Commit(); err != nil { + s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ + "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), + }) + } +} + +func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) { + if client == nil { + return false, errors.New("nil payment client") + } + oid := strconv.FormatInt(orderID, 10) + detail, _ := json.Marshal(map[string]any{ + "baseAmount": baseAmount, + "status": "reserved", + }) + rows, err := client.QueryContext(ctx, ` +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW() +WHERE NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs + WHERE order_id = $1 + AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') +) +ON CONFLICT (order_id, action) DO NOTHING +RETURNING id`, oid, string(detail)) + if err != nil { + return false, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + if err := rows.Err(); err != nil { + return false, err + } + return false, nil + } + var claimID int64 + if err := rows.Scan(&claimID); err != nil { + return false, err + } + return true, nil +} + +func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error { + if client == nil { + return errors.New("nil payment client") + } + oid := strconv.FormatInt(orderID, 10) + detailJSON, _ := json.Marshal(detail) + updated, err := client.PaymentAuditLog.Update(). + Where( + paymentauditlog.OrderIDEQ(oid), + paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"), + ). + SetAction(action). + SetDetail(string(detailJSON)). + SetOperator("system"). + Save(ctx) + if err != nil { + return err + } + if updated == 0 { + return errors.New("affiliate rebate claim log not found") + } + return nil +} + func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) { now := time.Now() r := psErrMsg(cause) diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 97fd76a0..15f6feeb 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -170,17 +170,18 @@ type TopUserStat struct { // --- Service --- type PaymentService struct { - providerMu sync.Mutex - providersLoaded bool - entClient *dbent.Client - registry *payment.Registry - loadBalancer payment.LoadBalancer - redeemService *RedeemService - subscriptionSvc *SubscriptionService - configService *PaymentConfigService - userRepo UserRepository - groupRepo GroupRepository - resumeService *PaymentResumeService + providerMu sync.Mutex + providersLoaded bool + entClient *dbent.Client + registry *payment.Registry + loadBalancer payment.LoadBalancer + redeemService *RedeemService + subscriptionSvc *SubscriptionService + configService *PaymentConfigService + userRepo UserRepository + groupRepo GroupRepository + resumeService *PaymentResumeService + affiliateService *AffiliateService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { @@ -189,6 +190,13 @@ func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, load return svc } +func (s *PaymentService) SetAffiliateService(affiliateService *AffiliateService) { + if s == nil { + return + } + s.affiliateService = affiliateService +} + // --- Provider Registry --- // EnsureProviders lazily initializes the provider registry on first call. diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index c79d8949..f3801c48 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "math" "net/url" "sort" "strconv" @@ -1167,6 +1168,8 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) + updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { @@ -1719,6 +1722,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", @@ -1846,6 +1850,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } + if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil { + result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate) + } else { + result.AffiliateRebateRate = AffiliateRebateRateDefault + } result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 @@ -2130,6 +2139,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin return result } +func clampAffiliateRebateRate(value float64) float64 { + if math.IsNaN(value) || math.IsInf(value, 0) { + return AffiliateRebateRateDefault + } + if value < AffiliateRebateRateMin { + return AffiliateRebateRateMin + } + if value > AffiliateRebateRateMax { + return AffiliateRebateRateMax + } + return value +} + func isFalseSettingValue(value string) bool { switch strings.ToLower(strings.TrimSpace(value)) { case "false", "0", "off", "disabled": diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ddd4fff6..8a3bd421 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -106,6 +106,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + AffiliateRebateRate float64 DefaultUserRPMLimit int DefaultSubscriptions []DefaultSubscriptionSetting diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 86bfc327..d8a6a332 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -391,6 +391,53 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit return svc } +func ProvideAuthService( + entClient *dbent.Client, + userRepo UserRepository, + redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, + cfg *config.Config, + settingService *SettingService, + emailService *EmailService, + turnstileService *TurnstileService, + emailQueueService *EmailQueueService, + promoService *PromoService, + defaultSubAssigner DefaultSubscriptionAssigner, + affiliateService *AffiliateService, +) *AuthService { + svc := NewAuthService( + entClient, + userRepo, + redeemRepo, + refreshTokenCache, + cfg, + settingService, + emailService, + turnstileService, + emailQueueService, + promoService, + defaultSubAssigner, + ) + svc.SetAffiliateService(affiliateService) + return svc +} + +func ProvidePaymentService( + entClient *dbent.Client, + registry *payment.Registry, + loadBalancer payment.LoadBalancer, + redeemService *RedeemService, + subscriptionSvc *SubscriptionService, + configService *PaymentConfigService, + userRepo UserRepository, + groupRepo GroupRepository, + affiliateService *AffiliateService, +) *PaymentService { + svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo) + svc.SetAffiliateService(affiliateService) + return svc +} + // ProvideBillingCacheService wires BillingCacheService with its RPM dependencies. func ProvideBillingCacheService( cache BillingCache, @@ -407,7 +454,7 @@ func ProvideBillingCacheService( // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services - NewAuthService, + ProvideAuthService, NewUserService, NewAPIKeyService, ProvideAPIKeyAuthCacheInvalidator, @@ -486,8 +533,9 @@ var ProviderSet = wire.NewSet( NewGroupCapacityService, NewChannelService, NewModelPricingResolver, + NewAffiliateService, ProvidePaymentConfigService, - NewPaymentService, + ProvidePaymentService, ProvidePaymentOrderExpiryService, ProvideBalanceNotifyService, ProvideChannelMonitorService, diff --git a/backend/migrations/130_add_user_affiliates.sql b/backend/migrations/130_add_user_affiliates.sql new file mode 100644 index 00000000..d8c001e0 --- /dev/null +++ b/backend/migrations/130_add_user_affiliates.sql @@ -0,0 +1,20 @@ +CREATE TABLE IF NOT EXISTS user_affiliates ( + user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + aff_code VARCHAR(32) NOT NULL UNIQUE, + inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + aff_count INTEGER NOT NULL DEFAULT 0, + aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0, + aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id); +CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota); + +COMMENT ON TABLE user_affiliates IS '用户邀请返利信息'; +COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码'; +COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID'; +COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数'; +COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额'; +COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额'; diff --git a/backend/migrations/131_affiliate_rebate_hardening.sql b/backend/migrations/131_affiliate_rebate_hardening.sql new file mode 100644 index 00000000..81e37a9e --- /dev/null +++ b/backend/migrations/131_affiliate_rebate_hardening.sql @@ -0,0 +1,58 @@ +-- 1) Normalize historical affiliate rebate rate values. +-- Legacy compatibility treated 0 20%). +-- We now use pure percentage semantics, so convert persisted fractional values once. +UPDATE settings +SET value = to_char((value::numeric * 100), 'FM999999990.########'), + updated_at = NOW() +WHERE key = 'affiliate_rebate_rate' + AND value ~ '^-?[0-9]+(\\.[0-9]+)?$' + AND value::numeric > 0 + AND value::numeric <= 1; + +-- 2) Affiliate ledger for accrual/transfer traceability. +CREATE TABLE IF NOT EXISTS user_affiliate_ledger ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + action VARCHAR(32) NOT NULL, + amount DECIMAL(20,8) NOT NULL, + source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id); +CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action); + +COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)'; +COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer'; + +-- 3) Enforce idempotency at DB layer for payment audit actions. +WITH ranked AS ( + SELECT id, + ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn + FROM payment_audit_logs +) +DELETE FROM payment_audit_logs p +USING ranked r +WHERE p.id = r.id + AND r.rn > 1; + +CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq +ON payment_audit_logs(order_id, action); + +-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders. +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT po.id::text, + 'AFFILIATE_REBATE_SKIPPED', + '{"reason":"baseline before affiliate rebate idempotency rollout"}', + 'system', + NOW() +FROM payment_orders po +WHERE po.order_type = 'balance' + AND po.status = 'COMPLETED' + AND NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs pal + WHERE pal.order_id = po.id::text + AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') + ); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index b9f24663..971c2314 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -308,6 +308,7 @@ export interface SystemSettings { totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings default_balance: number; + affiliate_rebate_rate: number; default_concurrency: number; default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; @@ -489,6 +490,7 @@ export interface UpdateSettingsRequest { invitation_code_enabled?: boolean; totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; + affiliate_rebate_rate?: number; default_concurrency?: number; default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index fd3cedb9..da7a91eb 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -9,7 +9,14 @@ import { prepareOAuthBindAccessTokenCookie, type WeChatOAuthPublicSettings, } from './auth' -import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types' +import type { + User, + ChangePasswordRequest, + NotifyEmailEntry, + UserAuthProvider, + UserAffiliateDetail, + AffiliateTransferResponse +} from '@/types' /** * Get current user profile @@ -168,6 +175,16 @@ export async function startOAuthBinding( window.location.href = startURL } +export async function getAffiliateDetail(): Promise { + const { data } = await apiClient.get('/user/aff') + return data +} + +export async function transferAffiliateQuota(): Promise { + const { data } = await apiClient.post('/user/aff/transfer') + return data +} + export const userAPI = { getProfile, updateProfile, @@ -180,7 +197,9 @@ export const userAPI = { bindEmailIdentity, unbindAuthIdentity, buildOAuthBindingStartURL, - startOAuthBinding + startOAuthBinding, + getAffiliateDetail, + transferAffiliateQuota } export default userAPI diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 910f24cd..a3a8c30e 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -656,6 +656,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] { { path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment }, { path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment }, { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, + { path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true }, { path: '/profile', label: t('nav.profile'), icon: UserIcon }, ...customMenuItemsForUser.value.map((item): NavItem => ({ path: `/custom/${item.id}`, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index eeb6087b..e7514f0e 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -346,6 +346,7 @@ export default { apiKeys: 'API Keys', usage: 'Usage', redeem: 'Redeem', + affiliate: 'Affiliate Rebates', profile: 'Profile', users: 'Users', groups: 'Groups', @@ -972,6 +973,47 @@ export default { } }, + affiliate: { + title: 'Affiliate Rebates', + description: 'Invite new users and convert your rebate quota into account balance', + yourCode: 'Your Affiliate Code', + inviteLink: 'Invite Link', + copyCode: 'Copy Code', + copyLink: 'Copy Link', + codeCopied: 'Affiliate code copied', + linkCopied: 'Invite link copied', + loadFailed: 'Failed to load affiliate data', + transferFailed: 'Failed to transfer affiliate quota', + stats: { + invitedUsers: 'Invited Users', + availableQuota: 'Available Rebate Quota', + totalQuota: 'Historical Rebate Quota' + }, + transfer: { + title: 'Transfer Rebate Quota', + description: 'Move available rebate quota into your account balance', + button: 'Transfer to Balance', + transferring: 'Transferring...', + empty: 'No available rebate quota', + success: '{amount} has been transferred to your balance' + }, + invitees: { + title: 'Invited Users', + empty: 'No invited users yet', + columns: { + email: 'Email', + username: 'Username', + joinedAt: 'Joined At' + } + }, + tips: { + title: 'How It Works', + line1: 'Share your affiliate code or invite link with new users.', + line2: 'When invitees recharge, you receive rebate quota based on the configured rate.', + line3: 'Transfer rebate quota to balance at any time.' + } + }, + // Redeem redeem: { title: 'Redeem Code', @@ -4837,6 +4879,9 @@ export default { description: 'Default values for new users', defaultBalance: 'Default Balance', defaultBalanceHint: 'Initial balance for new users', + affiliateRebateRate: 'Affiliate Rebate Rate', + affiliateRebateRateHint: + 'Rebate percentage credited to inviter after recharge (0-100%, e.g. 10 means 10%)', defaultConcurrency: 'Default Concurrency', defaultConcurrencyHint: 'Maximum concurrent requests for new users', defaultUserRpmLimit: 'Default User RPM Limit', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 5f91a2f6..3057f93e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -346,6 +346,7 @@ export default { apiKeys: 'API 密钥', usage: '使用记录', redeem: '兑换', + affiliate: '邀请返利', profile: '个人资料', users: '用户管理', groups: '分组管理', @@ -976,6 +977,47 @@ export default { } }, + affiliate: { + title: '邀请返利', + description: '邀请新用户注册,并将返利额度转入账户余额', + yourCode: '我的邀请码', + inviteLink: '邀请链接', + copyCode: '复制邀请码', + copyLink: '复制链接', + codeCopied: '邀请码已复制', + linkCopied: '邀请链接已复制', + loadFailed: '加载邀请返利数据失败', + transferFailed: '转入余额失败', + stats: { + invitedUsers: '邀请人数', + availableQuota: '可转返利额度', + totalQuota: '历史返利额度' + }, + transfer: { + title: '返利额度转余额', + description: '将当前可用返利额度一键转入账户余额', + button: '转入余额', + transferring: '转入中...', + empty: '当前没有可转入额度', + success: '已转入余额:{amount}' + }, + invitees: { + title: '已邀请用户', + empty: '暂无邀请记录', + columns: { + email: '邮箱', + username: '用户名', + joinedAt: '注册时间' + } + }, + tips: { + title: '使用说明', + line1: '将邀请码或邀请链接分享给新用户。', + line2: '被邀请用户充值后,你可获得对应比例的返利额度。', + line3: '返利额度可随时转入账户余额。' + } + }, + // Redeem redeem: { title: '兑换码', @@ -5000,6 +5042,8 @@ export default { description: '新用户的默认值', defaultBalance: '默认余额', defaultBalanceHint: '新用户的初始余额', + affiliateRebateRate: '邀请返利比例', + affiliateRebateRateHint: '充值后返给邀请人的比例(0-100%,例如填写 10 表示返利 10%)', defaultConcurrency: '默认并发数', defaultConcurrencyHint: '新用户的最大并发请求数', defaultUserRpmLimit: '默认用户 RPM 限制', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index dc886b23..2b85b97e 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -197,6 +197,18 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'redeem.description' } }, + { + path: '/affiliate', + name: 'Affiliate', + component: () => import('@/views/user/AffiliateView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: false, + title: 'Affiliate', + titleKey: 'affiliate.title', + descriptionKey: 'affiliate.description' + } + }, { path: '/available-channels', name: 'UserAvailableChannels', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 2329bb25..e2f41900 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -122,6 +122,29 @@ export interface RegisterRequest { turnstile_token?: string promo_code?: string invitation_code?: string + aff_code?: string +} + +export interface AffiliateInvitee { + user_id: number + email: string + username: string + created_at?: string +} + +export interface UserAffiliateDetail { + user_id: number + aff_code: string + inviter_id?: number | null + aff_count: number + aff_quota: number + aff_history_quota: number + invitees: AffiliateInvitee[] +} + +export interface AffiliateTransferResponse { + transferred_quota: number + balance: number } export interface SendVerifyCodeRequest { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 3e167938..6da4b21a 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2153,6 +2153,31 @@ {{ t("admin.settings.defaults.defaultBalanceHint") }}

+
+ +
+ + % +
+

+ {{ t("admin.settings.defaults.affiliateRebateRateHint") }} +

+
+
+ +