From b2bdba78dd15c4454d9236ea1a50258d68bbff98 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 3 May 2026 14:56:09 +0800 Subject: [PATCH 1/8] stabilize image request handling --- backend/internal/service/openai_images.go | 62 ++++++++++- .../internal/service/openai_images_test.go | 103 ++++++++++++++++++ 2 files changed, 159 insertions(+), 6 deletions(-) diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 4badcb1c..3da76525 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -596,7 +596,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( var usage OpenAIUsage imageCount := parsed.N var firstTokenMs *int - if parsed.Stream { + if parsed.Stream && isEventStreamResponse(resp.Header) { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { return nil, err @@ -811,6 +811,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int + var fallbackBody bytes.Buffer + fallbackBytes := int64(0) + fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg) + seenSSEData := false + fallbackTooLarge := false for { line, err := reader.ReadBytes('\n') @@ -824,11 +829,24 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( } flusher.Flush() - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { - dataBytes := []byte(data) - mergeOpenAIUsage(&usage, dataBytes) - if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { - imageCount = count + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok { + if data != "" && data != "[DONE]" { + seenSSEData = true + fallbackBody.Reset() + fallbackBytes = 0 + dataBytes := []byte(data) + mergeOpenAIUsage(&usage, dataBytes) + if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount { + imageCount = count + } + } + } else if !seenSSEData && !fallbackTooLarge { + fallbackBytes += int64(len(line)) + if fallbackBytes <= fallbackLimit { + _, _ = fallbackBody.Write(line) + } else { + fallbackTooLarge = true + fallbackBody.Reset() } } } @@ -839,9 +857,41 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( return OpenAIUsage{}, 0, firstTokenMs, err } } + if !seenSSEData && fallbackBody.Len() > 0 { + body := bytes.TrimSpace(fallbackBody.Bytes()) + if len(body) > 0 { + mergeOpenAIUsage(&usage, body) + if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount { + imageCount = count + } + } + } return usage, imageCount, firstTokenMs, nil } +func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int { + if count := extractOpenAIImageCountFromJSONBytes(body); count > 0 { + return count + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + if count := int(gjson.GetBytes(body, "usage.images").Int()); count > 0 { + return count + } + if count := int(gjson.GetBytes(body, "tool_usage.image_gen.images").Int()); count > 0 { + return count + } + eventType := strings.TrimSpace(gjson.GetBytes(body, "type").String()) + if eventType == "" || !strings.HasSuffix(eventType, ".completed") { + return 0 + } + if gjson.GetBytes(body, "b64_json").Exists() || gjson.GetBytes(body, "url").Exists() { + return 1 + } + return 0 +} + func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { if dst == nil { return diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 47113d4d..681e0e8e 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -446,6 +446,109 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseU require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamJSONResponseBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_stream_json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"usage":{"input_tokens":12,"output_tokens":21,"output_tokens_details":{"image_tokens":9}},"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 7, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 21, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.ImageOutputTokens) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbackBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_json_mislabeled"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000009,"usage":{"input_tokens":10,"output_tokens":18,"output_tokens_details":{"image_tokens":8}},"data":[{"b64_json":"ZmluYWw="}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) + require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) { + body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`) + + require.Equal(t, 1, extractOpenAIImagesBillableCountFromJSONBytes(body)) +} + func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { gin.SetMode(gin.TestMode) From 72d5ee4cd1d57d2648dcc3e23218834505cbc402 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 3 May 2026 17:11:27 +0800 Subject: [PATCH 2/8] fix: drain OpenAI compat streams for usage --- .../pkg/apicompat/anthropic_responses_test.go | 39 ++ .../chatcompletions_responses_test.go | 43 ++ .../pkg/apicompat/responses_to_anthropic.go | 4 +- .../apicompat/responses_to_chatcompletions.go | 4 +- backend/internal/pkg/apicompat/types.go | 2 +- backend/internal/service/gateway_service.go | 7 + .../service/gateway_service_streaming_test.go | 13 + .../service/openai_compat_model_test.go | 283 +++++++++++++ .../openai_gateway_chat_completions.go | 221 +++++----- .../openai_gateway_chat_completions_test.go | 258 ++++++++++++ .../service/openai_gateway_messages.go | 376 +++++++++++++----- .../service/openai_gateway_service.go | 4 +- .../service/openai_oauth_passthrough_test.go | 92 +++++ 13 files changed, 1141 insertions(+), 205 deletions(-) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index e8b25c2b..edde85d3 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 12, events[0].Usage.InputTokens) + assert.Equal(t, 4, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + +func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "max_tokens", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { state := NewResponsesEventToAnthropicState() ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 35d42999..bf5c23d5 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) { assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) } +func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { state := NewResponsesEventToChatState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 489ed238..b76f384d 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents( return resToAnthHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return resToAnthHandleBlockDone(state) - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToAnthHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 61b3bf9c..2386771d 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent return resToChatHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return nil - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToChatHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index f8c6b75f..0ff2cf49 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct { type ResponsesStreamEvent struct { Type string `json:"type"` - // response.created / response.completed / response.failed / response.incomplete + // response.created / response.completed / response.done / response.failed / response.incomplete Response *ResponsesResponse `json:"response,omitempty"` // response.output_item.added / response.output_item.done diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 074013c3..67d19720 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance } func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } if !stream { return ctx, func() {} } + return context.WithoutCancel(ctx), func() {} +} + +func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { return context.Background(), func() {} } diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go index c8803d39..39a7d3b0 100644 --- a/backend/internal/service/gateway_service_streaming_test.go +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +type upstreamContextTestKey string + func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 7, result.usage.OutputTokens) } + +func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) { + parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value")) + upstreamCtx, release := detachUpstreamContext(parent) + defer release() + + cancel() + + require.NoError(t, upstreamCtx.Err()) + require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key"))) +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 4396c15f..1129bf04 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -3,13 +3,16 @@ package service import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -18,6 +21,51 @@ import ( "github.com/tidwall/gjson" ) +type openAICompatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAICompatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +type openAICompatBlockingReadCloser struct { + data []byte + offset int + closed chan struct{} + closeOnce sync.Once +} + +func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser { + return &openAICompatBlockingReadCloser{ + data: data, + closed: make(chan struct{}), + } +} + +func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) { + if r.offset < len(r.data) { + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil + } + <-r.closed + return 0, io.EOF +} + +func (r *openAICompatBlockingReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.closed) + }) + return nil +} + func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { t.Parallel() @@ -228,3 +276,238 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon require.NotNil(t, result) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) } + +func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 5822ae4c..3456cce0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai chat_completions buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai chat_completions buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ RequestID: requestID, @@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } chunks := apicompat.ResponsesEventToChatChunks(&event, state) - for _, chunk := range chunks { - sse, err := apicompat.ChatChunkToSSE(chunk) - if err != nil { - logger.L().Warn("openai chat_completions stream: failed to marshal chunk", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai chat_completions stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(chunks) > 0 { + if len(chunks) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } finalizeStream := func() (*OpenAIForwardResult, error) { - if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } } } // Send [DONE] sentinel - fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck - c.Writer.Flush() + if !clientDisconnected { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during done flush", + zap.String("request_id", requestID), + ) + } + } + if !clientDisconnected { + c.Writer.Flush() + } return resultWithUsage(), nil } @@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // Determine keepalive interval keepaliveInterval := time.Duration(0) @@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } // No keepalive: fast synchronous path - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // With keepalive: goroutine + channel + select @@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { select { case ev, ok := <-events: if !ok { - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai chat_completions stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 6846e03a..1236fb2c 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -1,13 +1,36 @@ package service import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) +type openAIChatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAIChatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestNormalizeResponsesRequestServiceTier(t *testing.T) { t.Parallel() @@ -73,3 +96,238 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } + +func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 4e0ebb2e..9fd6f04c 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai messages buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - // Terminal events carry the complete ResponsesResponse with output + usage. - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai messages buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( }, nil } +func isOpenAICompatResponsesTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.incomplete", "response.failed": + return true + default: + return false + } +} + +func isOpenAICompatDoneSentinelLine(line string) bool { + payload, ok := extractOpenAISSEDataLine(line) + return ok && strings.TrimSpace(payload) == "[DONE]" +} + +func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( + resp *http.Response, + logPrefix string, + requestID string, +) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) { + acc := apicompat.NewBufferedResponseAccumulator() + var usage OpenAIUsage + if resp == nil || resp.Body == nil { + return nil, usage, acc, errors.New("upstream response body is nil") + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var timeoutCh <-chan time.Time + var timeoutTimer *time.Timer + resetTimeout := func() { + if streamInterval <= 0 { + return + } + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(streamInterval) + timeoutCh = timeoutTimer.C + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + timeoutTimer.Reset(streamInterval) + } + stopTimeout := func() { + if timeoutTimer == nil { + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + } + resetTimeout() + defer stopTimeout() + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + go func() { + defer close(events) + for scanner.Scan() { + select { + case events <- scanEvent{line: scanner.Text()}: + case <-done: + return + } + } + if err := scanner.Err(); err != nil { + select { + case events <- scanEvent{err: err}: + case <-done: + } + } + }() + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return nil, usage, acc, nil + } + resetTimeout() + if ev.err != nil { + if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) { + logger.L().Warn(logPrefix+": read error", + zap.Error(ev.err), + zap.String("request_id", requestID), + ) + } + return nil, usage, acc, ev.err + } + + payload, ok := extractOpenAISSEDataLine(ev.line) + if !ok || payload == "" { + continue + } + if strings.TrimSpace(payload) == "[DONE]" { + return nil, usage, acc, nil + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn(logPrefix+": failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + acc.ProcessEvent(&event) + + if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } + return event.Response, usage, acc, nil + } + + case <-timeoutCh: + _ = resp.Body.Close() + logger.L().Warn(logPrefix+": data interval timeout", + zap.String("request_id", requestID), + zap.Duration("interval", streamInterval), + ) + return nil, usage, acc, fmt.Errorf("stream data interval timeout") + } + } +} + // handleAnthropicStreamingResponse reads Responses SSE events from upstream, // converts each to Anthropic SSE events, and writes them to the client. // When StreamKeepaliveInterval is configured, it uses a goroutine + channel @@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ @@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // processDataLine handles a single "data: ..." SSE line from upstream. - // Returns (clientDisconnected bool). processDataLine := func(payload string) bool { if firstChunk { firstChunk = false @@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } // Convert to Anthropic events events := apicompat.ResponsesEventToAnthropicEvents(&event, state) - for _, evt := range events { - sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) - if err != nil { - logger.L().Warn("openai messages stream: failed to marshal event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai messages stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(events) > 0 { + if len(events) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } // finalizeStream sends any remaining Anthropic events and returns the result. finalizeStream := func() (*OpenAIForwardResult, error) { - if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected { for _, evt := range finalEvents { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() } - c.Writer.Flush() } return resultWithUsage(), nil } @@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // ── Determine keepalive interval ── keepaliveInterval := time.Duration(0) @@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // ── No keepalive: fast synchronous path (no goroutine overhead) ── - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // ── With keepalive: goroutine + channel + select ── @@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { @@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( case ev, ok := <-events: if !ok { // Upstream closed - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai messages stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( logger.L().Info("openai messages stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } @@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string }, }) } + +func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage { + if usage == nil { + return OpenAIUsage{} + } + result := OpenAIUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + } + if usage.InputTokensDetails != nil { + result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens + } + return result +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ed69730c..d1d73586 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco httpInvalidEncryptedContentRetryTried := false for { // Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) releaseUpstreamCtx() if err != nil { @@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) releaseUpstreamCtx() if err != nil { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 049ffdd8..87a05b14 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) } +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te require.Contains(t, string(upstream.lastBody), `"stream":true`) } +func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { gin.SetMode(gin.TestMode) From 47fb38bca13fd53d0c85e9db113c5405a6d6b3e4 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 3 May 2026 17:11:52 +0800 Subject: [PATCH 3/8] fix: record zero OpenAI usage logs --- .../service/openai_compat_model_test.go | 8 ++- .../service/openai_gateway_403_reset_test.go | 19 +++++-- .../openai_gateway_chat_completions_test.go | 8 ++- .../service/openai_gateway_messages.go | 18 +++---- .../openai_gateway_record_usage_test.go | 50 +++++++++++++++++++ .../service/openai_gateway_service.go | 7 --- 6 files changed, 85 insertions(+), 25 deletions(-) diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 1129bf04..840784bf 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -336,7 +336,9 @@ func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing. upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) - defer upstreamStream.Close() + defer func() { + require.NoError(t, upstreamStream.Close()) + }() upstream := &httpUpstreamRecorder{resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}}, @@ -389,7 +391,9 @@ func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testi upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) - defer upstreamStream.Close() + defer func() { + require.NoError(t, upstreamStream.Close()) + }() upstream := &httpUpstreamRecorder{resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}}, diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go index c6805464..440b94a9 100644 --- a/backend/internal/service/openai_gateway_403_reset_test.go +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -20,20 +20,29 @@ func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accou return nil } -func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterForZeroUsage(t *testing.T) { counter := &openAI403CounterResetStub{} rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) rateLimitSvc.SetOpenAI403CounterCache(counter) - svc := &OpenAIGatewayService{ - rateLimitService: rateLimitSvc, - } + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + svc.rateLimitService = rateLimitSvc err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ - Result: &OpenAIForwardResult{}, + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage_reset_403", + Model: "gpt-5.1", + }, + APIKey: &APIKey{ID: 1001, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2001}, Account: &Account{ID: 777, Platform: PlatformOpenAI}, }) require.NoError(t, err) require.Equal(t, []int64{777}, counter.resetCalls) + require.Equal(t, 1, usageRepo.calls) } diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 1236fb2c..c129a4df 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -156,7 +156,9 @@ func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *te upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) - defer upstreamStream.Close() + defer func() { + require.NoError(t, upstreamStream.Close()) + }() upstream := &httpUpstreamRecorder{resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}}, @@ -209,7 +211,9 @@ func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) - defer upstreamStream.Close() + defer func() { + require.NoError(t, upstreamStream.Close()) + }() upstream := &httpUpstreamRecorder{resp: &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}}, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 9fd6f04c..5f3bf5c1 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -441,13 +441,13 @@ func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( return nil, usage, acc, ev.err } + if isOpenAICompatDoneSentinelLine(ev.line) { + return nil, usage, acc, nil + } payload, ok := extractOpenAISSEDataLine(ev.line) if !ok || payload == "" { continue } - if strings.TrimSpace(payload) == "[DONE]" { - return nil, usage, acc, nil - } var event apicompat.ResponsesStreamEvent if err := json.Unmarshal([]byte(payload), &event); err != nil { @@ -640,13 +640,13 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } payload, ok := extractOpenAISSEDataLine(line) if !ok { continue } - if strings.TrimSpace(payload) == "[DONE]" { - return missingTerminalErr() - } if processDataLine(payload) { return finalizeStream() } @@ -713,13 +713,13 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } lastDataAt = time.Now() line := ev.line + if isOpenAICompatDoneSentinelLine(line) { + return missingTerminalErr() + } payload, ok := extractOpenAISSEDataLine(line) if !ok { continue } - if strings.TrimSpace(payload) == "[DONE]" { - return missingTerminalErr() - } if processDataLine(payload) { return finalizeStream() } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 47ff4e3b..76fbb794 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -186,6 +186,56 @@ func max(a, b int) int { return b } +func TestOpenAIGatewayServiceRecordUsage_ZeroUsageStillWritesUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_zero_usage", + Usage: OpenAIUsage{}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000, Quota: 100, Group: &Group{RateMultiplier: 1}}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000, Type: AccountTypeAPIKey}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) + require.Equal(t, 0, quotaSvc.rateLimitCalls) + + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "resp_zero_usage", usageRepo.lastLog.RequestID) + require.Zero(t, usageRepo.lastLog.InputTokens) + require.Zero(t, usageRepo.lastLog.OutputTokens) + require.Zero(t, usageRepo.lastLog.CacheCreationTokens) + require.Zero(t, usageRepo.lastLog.CacheReadTokens) + require.Zero(t, usageRepo.lastLog.ImageOutputTokens) + require.Zero(t, usageRepo.lastLog.ImageCount) + require.Zero(t, usageRepo.lastLog.InputCost) + require.Zero(t, usageRepo.lastLog.OutputCost) + require.Zero(t, usageRepo.lastLog.TotalCost) + require.Zero(t, usageRepo.lastLog.ActualCost) + + require.NotNil(t, billingRepo.lastCmd) + require.Zero(t, billingRepo.lastCmd.BalanceCost) + require.Zero(t, billingRepo.lastCmd.SubscriptionCost) + require.Zero(t, billingRepo.lastCmd.APIKeyQuotaCost) + require.Zero(t, billingRepo.lastCmd.APIKeyRateLimitCost) + require.Zero(t, billingRepo.lastCmd.AccountQuotaCost) +} + func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T) { groupID := int64(11) groupRate := 1.4 diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d1d73586..b818fa4a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -5041,13 +5041,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) } - // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 - if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && - result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && - result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { - return nil - } - apiKey := input.APIKey user := input.User account := input.Account From 6a41cf6a51795e3564362c68b7cd72a62338deca Mon Sep 17 00:00:00 2001 From: lyen1688 Date: Sun, 3 May 2026 15:43:56 +0800 Subject: [PATCH 4/8] feat: add admin affiliate record pages --- .../handler/admin/affiliate_handler.go | 108 +++++ backend/internal/repository/affiliate_repo.go | 364 +++++++++++++++++ backend/internal/server/routes/admin.go | 5 + backend/internal/service/affiliate_service.go | 125 ++++++ frontend/src/api/admin/affiliates.ts | 121 ++++++ frontend/src/components/layout/AppSidebar.vue | 13 + frontend/src/i18n/locales/en.ts | 47 +++ frontend/src/i18n/locales/zh.ts | 47 +++ frontend/src/router/index.ts | 40 ++ .../affiliates/AdminAffiliateInvitesView.vue | 7 + .../affiliates/AdminAffiliateRebatesView.vue | 7 + .../affiliates/AdminAffiliateRecordsTable.vue | 386 ++++++++++++++++++ .../AdminAffiliateTransfersView.vue | 7 + 13 files changed, 1277 insertions(+) create mode 100644 frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue create mode 100644 frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue create mode 100644 frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue create mode 100644 frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go index 97e649ec..d443d344 100644 --- a/backend/internal/handler/admin/affiliate_handler.go +++ b/backend/internal/handler/admin/affiliate_handler.go @@ -2,8 +2,11 @@ package admin import ( "strconv" + "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -181,3 +184,108 @@ func (h *AffiliateHandler) LookupUsers(c *gin.Context) { } response.Success(c, result) } + +// GetUserOverview returns one user's affiliate overview. +// GET /api/v1/admin/affiliates/users/:user_id/overview +func (h *AffiliateHandler) GetUserOverview(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64) + if err != nil || userID <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + overview, err := h.affiliateService.AdminGetUserOverview(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, overview) +} + +// ListInviteRecords returns all inviter-invitee relationships. +// GET /api/v1/admin/affiliates/invites +func (h *AffiliateHandler) ListInviteRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListInviteRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListRebateRecords returns all order-level affiliate rebate records. +// GET /api/v1/admin/affiliates/rebates +func (h *AffiliateHandler) ListRebateRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListRebateRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +// ListTransferRecords returns all affiliate quota-to-balance transfer records. +// GET /api/v1/admin/affiliates/transfers +func (h *AffiliateHandler) ListTransferRecords(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + filter := parseAffiliateRecordFilter(c, page, pageSize) + items, total, err := h.affiliateService.AdminListTransferRecords(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, items, total, filter.Page, filter.PageSize) +} + +func parseAffiliateRecordFilter(c *gin.Context, page, pageSize int) service.AffiliateRecordFilter { + filter := service.AffiliateRecordFilter{ + Search: c.Query("search"), + Page: page, + PageSize: pageSize, + SortBy: c.Query("sort_by"), + SortDesc: c.Query("sort_order") != "asc", + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + userTZ := c.Query("timezone") + if t := parseAffiliateRecordStartTime(c.Query("start_at"), userTZ); t != nil { + filter.StartAt = t + } + if t := parseAffiliateRecordEndTime(c.Query("end_at"), userTZ); t != nil { + filter.EndAt = t + } + return filter +} + +func parseAffiliateRecordStartTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + return &parsed + } + return nil +} + +func parseAffiliateRecordEndTime(raw string, userTZ string) *time.Time { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + return &parsed + } + if parsed, err := timezone.ParseInUserLocation("2006-01-02", raw, userTZ); err == nil { + end := parsed.AddDate(0, 0, 1).Add(-time.Nanosecond) + return &end + } + return nil +} diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index ef89e5b6..7ff497b6 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "database/sql" + "encoding/json" "errors" "fmt" "strings" @@ -332,6 +333,369 @@ LIMIT $2`, inviterID, limit) return invitees, nil } +func (r *affiliateRepository) ListAffiliateInviteRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateInviteRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ua.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "ua.inviter_id::text", "ua.user_id::text", "inviter_aff.aff_code", + }) + + total, err := queryAffiliateRecordCount(ctx, client, ` +SELECT COUNT(*) +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +`+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "inviter": "inviter.email", + "invitee": "invitee.email", + "aff_code": "inviter_aff.aff_code", + "total_rebate": "total_rebate", + "created_at": "ua.created_at", + }, "ua.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ua.inviter_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + ua.user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + COALESCE(inviter_aff.aff_code, ''), + COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate, + ua.created_at +FROM user_affiliates ua +JOIN users invitee ON invitee.id = ua.user_id +JOIN users inviter ON inviter.id = ua.inviter_id +JOIN user_affiliates inviter_aff ON inviter_aff.user_id = ua.inviter_id +LEFT JOIN user_affiliate_ledger ual + ON ual.user_id = ua.inviter_id + AND ual.source_user_id = ua.user_id + AND ual.action = 'accrue' +`+where+` +GROUP BY ua.inviter_id, inviter.email, inviter.username, ua.user_id, invitee.email, invitee.username, inviter_aff.aff_code, ua.created_at +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateInviteRecord, 0) + for rows.Next() { + var item service.AffiliateInviteRecord + if err := rows.Scan( + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.AffCode, + &item.TotalRebate, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateRebateRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateRebateRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "pal.created_at", []string{ + "inviter.email", "inviter.username", "invitee.email", "invitee.username", + "po.id::text", "po.out_trade_no", "po.payment_type", "po.status", + }) + baseJoin := ` +FROM payment_audit_logs pal +JOIN payment_orders po ON po.id::text = pal.order_id +JOIN user_affiliates invitee_aff ON invitee_aff.user_id = po.user_id +JOIN users invitee ON invitee.id = po.user_id +JOIN users inviter ON inviter.id = invitee_aff.inviter_id +WHERE pal.action = 'AFFILIATE_REBATE_APPLIED'` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "order": "po.id", + "inviter": "inviter.email", + "invitee": "invitee.email", + "order_amount": "po.amount", + "pay_amount": "po.pay_amount", + "payment_type": "po.payment_type", + "order_status": "po.status", + "created_at": "pal.created_at", + }, "pal.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT po.id, + po.out_trade_no, + invitee_aff.inviter_id, + COALESCE(inviter.email, ''), + COALESCE(inviter.username, ''), + po.user_id, + COALESCE(invitee.email, ''), + COALESCE(invitee.username, ''), + po.amount::double precision, + po.pay_amount::double precision, + COALESCE(pal.detail, ''), + po.payment_type, + po.status, + pal.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateRebateRecord, 0) + for rows.Next() { + var item service.AffiliateRebateRecord + var detail string + if err := rows.Scan( + &item.OrderID, + &item.OutTradeNo, + &item.InviterID, + &item.InviterEmail, + &item.InviterUsername, + &item.InviteeID, + &item.InviteeEmail, + &item.InviteeUsername, + &item.OrderAmount, + &item.PayAmount, + &detail, + &item.PaymentType, + &item.OrderStatus, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + item.RebateAmount = parseAffiliateRebateAmount(detail) + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) ListAffiliateTransferRecords(ctx context.Context, filter service.AffiliateRecordFilter) ([]service.AffiliateTransferRecord, int64, error) { + client := clientFromContext(ctx, r.client) + where, args := buildAffiliateRecordWhere(filter, "ual.created_at", []string{ + "u.email", "u.username", "u.id::text", + }) + baseJoin := ` +FROM user_affiliate_ledger ual +JOIN users u ON u.id = ual.user_id +JOIN user_affiliates ua ON ua.user_id = ual.user_id +WHERE ual.action = 'transfer'` + if where != "" { + where = strings.Replace(where, "WHERE ", " AND ", 1) + } + + total, err := queryAffiliateRecordCount(ctx, client, "SELECT COUNT(*) "+baseJoin+where, args...) + if err != nil { + return nil, 0, err + } + + orderBy := buildAffiliateRecordOrderBy(filter, map[string]string{ + "user": "u.email", + "amount": "ual.amount", + "current_balance": "u.balance", + "remaining_quota": "ua.aff_quota", + "frozen_quota": "ua.aff_frozen_quota", + "history_quota": "ua.aff_history_quota", + "created_at": "ual.created_at", + }, "ual.created_at") + args = append(args, filter.PageSize, (filter.Page-1)*filter.PageSize) + rows, err := client.QueryContext(ctx, ` +SELECT ual.id, + ual.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ual.amount::double precision, + u.balance::double precision, + ua.aff_quota::double precision, + ua.aff_frozen_quota::double precision, + ua.aff_history_quota::double precision, + ual.created_at +`+baseJoin+where+` +`+orderBy+` +LIMIT $`+fmt.Sprint(len(args)-1)+` OFFSET $`+fmt.Sprint(len(args)), args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + items := make([]service.AffiliateTransferRecord, 0) + for rows.Next() { + var item service.AffiliateTransferRecord + if err := rows.Scan( + &item.LedgerID, + &item.UserID, + &item.UserEmail, + &item.Username, + &item.Amount, + &item.CurrentBalance, + &item.RemainingQuota, + &item.FrozenQuota, + &item.HistoryQuota, + &item.CreatedAt, + ); err != nil { + return nil, 0, err + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return items, total, nil +} + +func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, userID int64) (*service.AffiliateUserOverview, error) { + if userID <= 0 { + return nil, service.ErrUserNotFound + } + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.aff_code, + COALESCE(ua.aff_rebate_rate_percent, 0)::double precision, + (ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate, + ua.aff_count, + COALESCE(rebated.rebated_invitee_count, 0), + ua.aff_quota::double precision, + ua.aff_history_quota::double precision +FROM user_affiliates ua +JOIN users u ON u.id = ua.user_id +LEFT JOIN ( + SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count + FROM user_affiliate_ledger + WHERE action = 'accrue' AND source_user_id IS NOT NULL + GROUP BY user_id +) rebated ON rebated.user_id = ua.user_id +WHERE ua.user_id = $1 +LIMIT 1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUserNotFound + } + + var overview service.AffiliateUserOverview + var customRate float64 + var hasCustomRate bool + if err := rows.Scan( + &overview.UserID, + &overview.Email, + &overview.Username, + &overview.AffCode, + &customRate, + &hasCustomRate, + &overview.InvitedCount, + &overview.RebatedInviteeCount, + &overview.AvailableQuota, + &overview.HistoryQuota, + ); err != nil { + return nil, err + } + if hasCustomRate { + overview.RebateRatePercent = customRate + overview.RebateRateCustom = true + } + return &overview, rows.Err() +} + +func buildAffiliateRecordWhere(filter service.AffiliateRecordFilter, timeColumn string, searchColumns []string) (string, []any) { + clauses := make([]string, 0, 3) + args := make([]any, 0, 3) + if filter.StartAt != nil { + args = append(args, *filter.StartAt) + clauses = append(clauses, fmt.Sprintf("%s >= $%d", timeColumn, len(args))) + } + if filter.EndAt != nil { + args = append(args, *filter.EndAt) + clauses = append(clauses, fmt.Sprintf("%s <= $%d", timeColumn, len(args))) + } + search := strings.TrimSpace(filter.Search) + if search != "" && len(searchColumns) > 0 { + args = append(args, "%"+strings.ToLower(search)+"%") + parts := make([]string, 0, len(searchColumns)) + for _, col := range searchColumns { + parts = append(parts, fmt.Sprintf("LOWER(%s) LIKE $%d", col, len(args))) + } + clauses = append(clauses, "("+strings.Join(parts, " OR ")+")") + } + if len(clauses) == 0 { + return "", args + } + return "WHERE " + strings.Join(clauses, " AND "), args +} + +func buildAffiliateRecordOrderBy(filter service.AffiliateRecordFilter, sortColumns map[string]string, fallbackColumn string) string { + column := sortColumns[filter.SortBy] + if column == "" { + column = fallbackColumn + } + direction := "DESC" + if !filter.SortDesc { + direction = "ASC" + } + return "ORDER BY " + column + " " + direction +} + +func queryAffiliateRecordCount(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) { + rows, err := client.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + if !rows.Next() { + return 0, rows.Err() + } + var total int64 + if err := rows.Scan(&total); err != nil { + return 0, err + } + return total, rows.Err() +} + +func parseAffiliateRebateAmount(detail string) float64 { + var payload struct { + RebateAmount float64 `json:"rebateAmount"` + } + if err := json.Unmarshal([]byte(detail), &payload); err != nil { + return 0 + } + return payload.RebateAmount +} + func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { if tx := dbent.TxFromContext(ctx); tx != nil { return fn(ctx, tx.Client()) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 1c786f50..fe4c4b1b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -602,11 +602,16 @@ func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) { affiliates := admin.Group("/affiliates") { + affiliates.GET("/invites", h.Admin.Affiliate.ListInviteRecords) + affiliates.GET("/rebates", h.Admin.Affiliate.ListRebateRecords) + affiliates.GET("/transfers", h.Admin.Affiliate.ListTransferRecords) + users := affiliates.Group("/users") { users.GET("", h.Admin.Affiliate.ListUsers) users.GET("/lookup", h.Admin.Affiliate.LookupUsers) users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate) + users.GET("/:user_id/overview", h.Admin.Affiliate.GetUserOverview) users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings) users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings) } diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index 5a4e91e7..d8a59135 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -110,6 +110,10 @@ type AffiliateRepository interface { SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) + ListAffiliateInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) + ListAffiliateRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) + ListAffiliateTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) + GetAffiliateUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) } // AffiliateAdminFilter 列表筛选条件 @@ -130,6 +134,71 @@ type AffiliateAdminEntry struct { AffCount int `json:"aff_count"` } +type AffiliateRecordFilter struct { + Search string + Page int + PageSize int + StartAt *time.Time + EndAt *time.Time + SortBy string + SortDesc bool +} + +type AffiliateInviteRecord struct { + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + AffCode string `json:"aff_code"` + TotalRebate float64 `json:"total_rebate"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateRebateRecord struct { + OrderID int64 `json:"order_id"` + OutTradeNo string `json:"out_trade_no"` + InviterID int64 `json:"inviter_id"` + InviterEmail string `json:"inviter_email"` + InviterUsername string `json:"inviter_username"` + InviteeID int64 `json:"invitee_id"` + InviteeEmail string `json:"invitee_email"` + InviteeUsername string `json:"invitee_username"` + OrderAmount float64 `json:"order_amount"` + PayAmount float64 `json:"pay_amount"` + RebateAmount float64 `json:"rebate_amount"` + PaymentType string `json:"payment_type"` + OrderStatus string `json:"order_status"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateTransferRecord struct { + LedgerID int64 `json:"ledger_id"` + UserID int64 `json:"user_id"` + UserEmail string `json:"user_email"` + Username string `json:"username"` + Amount float64 `json:"amount"` + CurrentBalance float64 `json:"current_balance"` + RemainingQuota float64 `json:"remaining_quota"` + FrozenQuota float64 `json:"frozen_quota"` + HistoryQuota float64 `json:"history_quota"` + CreatedAt time.Time `json:"created_at"` +} + +type AffiliateUserOverview struct { + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + AffCode string `json:"aff_code"` + RebateRatePercent float64 `json:"rebate_rate_percent"` + RebateRateCustom bool `json:"-"` + InvitedCount int `json:"invited_count"` + RebatedInviteeCount int `json:"rebated_invitee_count"` + AvailableQuota float64 `json:"available_quota"` + HistoryQuota float64 `json:"history_quota"` +} + type AffiliateService struct { repo AffiliateRepository settingService *SettingService @@ -488,3 +557,59 @@ func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter Affi } return s.repo.ListUsersWithCustomSettings(ctx, filter) } + +func (s *AffiliateService) AdminListInviteRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateInviteRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListRebateRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateRebateRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminListTransferRecords(ctx context.Context, filter AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) { + if s == nil || s.repo == nil { + return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + return s.repo.ListAffiliateTransferRecords(ctx, normalizeAffiliateRecordFilter(filter)) +} + +func (s *AffiliateService) AdminGetUserOverview(ctx context.Context, userID int64) (*AffiliateUserOverview, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_USER", "invalid user") + } + if s == nil || s.repo == nil { + return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") + } + overview, err := s.repo.GetAffiliateUserOverview(ctx, userID) + if err != nil { + return nil, err + } + if overview != nil { + if !overview.RebateRateCustom { + overview.RebateRatePercent = s.globalRebateRatePercent(ctx) + } + overview.RebateRatePercent = clampAffiliateRebateRate(overview.RebateRatePercent) + } + return overview, nil +} + +func normalizeAffiliateRecordFilter(filter AffiliateRecordFilter) AffiliateRecordFilter { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.PageSize > 100 { + filter.PageSize = 100 + } + filter.Search = strings.TrimSpace(filter.Search) + filter.SortBy = strings.TrimSpace(filter.SortBy) + return filter +} diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts index 22639bd2..37b03f00 100644 --- a/frontend/src/api/admin/affiliates.ts +++ b/frontend/src/api/admin/affiliates.ts @@ -23,6 +23,71 @@ export interface ListAffiliateUsersParams { search?: string } +export interface ListAffiliateRecordsParams { + page?: number + page_size?: number + search?: string + start_at?: string + end_at?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + timezone?: string +} + +export interface AffiliateInviteRecord { + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + aff_code: string + total_rebate: number + created_at: string +} + +export interface AffiliateRebateRecord { + order_id: number + out_trade_no: string + inviter_id: number + inviter_email: string + inviter_username: string + invitee_id: number + invitee_email: string + invitee_username: string + order_amount: number + pay_amount: number + rebate_amount: number + payment_type: string + order_status: string + created_at: string +} + +export interface AffiliateTransferRecord { + ledger_id: number + user_id: number + user_email: string + username: string + amount: number + current_balance: number + remaining_quota: number + frozen_quota: number + history_quota: number + created_at: string +} + +export interface AffiliateUserOverview { + user_id: number + email: string + username: string + aff_code: string + rebate_rate_percent: number + invited_count: number + rebated_invitee_count: number + available_quota: number + history_quota: number +} + export interface UpdateAffiliateUserRequest { aff_code?: string aff_rebate_rate_percent?: number | null @@ -97,12 +162,68 @@ export async function batchSetRate( return data } +function recordParams(params: ListAffiliateRecordsParams = {}) { + return { + page: params.page ?? 1, + page_size: params.page_size ?? 20, + search: params.search ?? '', + start_at: params.start_at || undefined, + end_at: params.end_at || undefined, + sort_by: params.sort_by || undefined, + sort_order: params.sort_order || undefined, + timezone: params.timezone || undefined, + } +} + +export async function listInviteRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/invites', + { params: recordParams(params) }, + ) + return data +} + +export async function listRebateRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/rebates', + { params: recordParams(params) }, + ) + return data +} + +export async function listTransferRecords( + params: ListAffiliateRecordsParams = {}, +): Promise> { + const { data } = await apiClient.get>( + '/admin/affiliates/transfers', + { params: recordParams(params) }, + ) + return data +} + +export async function getUserOverview( + userId: number, +): Promise { + const { data } = await apiClient.get( + `/admin/affiliates/users/${userId}/overview`, + ) + return data +} + export const affiliatesAPI = { listUsers, lookupUsers, updateUserSettings, clearUserSettings, batchSetRate, + listInviteRecords, + listRebateRecords, + listTransferRecords, + getUserOverview, } export default affiliatesAPI diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index d8e2794e..4488bf60 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -721,6 +721,19 @@ const adminNavItems = computed((): NavItem[] => { { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, + { + path: '/admin/affiliates', + label: t('nav.affiliateManagement'), + icon: UsersIcon, + hideInSimpleMode: true, + expandOnly: true, + featureFlag: flagAffiliate, + children: [ + { path: '/admin/affiliates/invites', label: t('nav.affiliateInviteRecords'), icon: UsersIcon }, + { path: '/admin/affiliates/rebates', label: t('nav.affiliateRebateRecords'), icon: OrderIcon }, + { path: '/admin/affiliates/transfers', label: t('nav.affiliateTransferRecords'), icon: CreditCardIcon }, + ], + }, { path: '/admin/orders', label: t('nav.orderManagement'), diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 2da121fb..50b19d2a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -347,6 +347,10 @@ export default { usage: 'Usage', redeem: 'Redeem', affiliate: 'Affiliate Rebates', + affiliateManagement: 'Affiliate Rebates', + affiliateInviteRecords: 'Invite Records', + affiliateRebateRecords: 'Rebate Records', + affiliateTransferRecords: 'Transfer Records', profile: 'Profile', users: 'Users', groups: 'Groups', @@ -1635,6 +1639,49 @@ export default { } }, + affiliates: { + invitesDescription: 'View site-wide inviter and invitee relationships', + rebatesDescription: 'View recharge orders that generated affiliate rebates', + transfersDescription: 'View affiliate quota transfers into account balance', + errors: { + loadFailed: 'Failed to load affiliate records' + }, + records: { + search: 'Search', + searchPlaceholder: 'Email, username, user ID, or order number', + startAt: 'Start date', + endAt: 'End date', + inviter: 'Inviter', + invitee: 'Invitee', + user: 'User', + affCode: 'Invite Code', + order: 'Order', + totalRebate: 'Total Rebate', + orderAmount: 'Top-up Amount', + payAmount: 'Paid Amount', + rebateAmount: 'Rebate Amount', + paymentType: 'Payment Method', + orderStatus: 'Order Status', + transferAmount: 'Transfer Amount', + currentBalance: 'Current Balance', + remainingQuota: 'Remaining Quota', + frozenQuota: 'Frozen Rebate', + historyQuota: 'Historical Rebate', + invitedAt: 'Invited At', + rebatedAt: 'Rebated At', + transferredAt: 'Transferred At' + }, + overview: { + title: 'Affiliate User Overview', + affCode: 'Invite Code', + rebateRate: 'Rebate Rate', + invitedCount: 'Invited Users', + rebatedInviteeCount: 'Rebated Invitees', + availableQuota: 'Available Quota', + historyQuota: 'Historical Rebate' + } + }, + // Users users: { title: 'User Management', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 7d266522..ac5735f5 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -347,6 +347,10 @@ export default { usage: '使用记录', redeem: '兑换', affiliate: '邀请返利', + affiliateManagement: '邀请返利', + affiliateInviteRecords: '邀请记录', + affiliateRebateRecords: '返利记录', + affiliateTransferRecords: '提取记录', profile: '个人资料', users: '用户管理', groups: '分组管理', @@ -1656,6 +1660,49 @@ export default { } }, + affiliates: { + invitesDescription: '查看全站邀请关系和被邀请用户累计返利', + rebatesDescription: '查看每一笔产生返利的充值订单', + transfersDescription: '查看返利额度转入账户余额的提取流水', + errors: { + loadFailed: '加载邀请返利记录失败' + }, + records: { + search: '搜索', + searchPlaceholder: '邮箱、用户名、用户 ID、订单号', + startAt: '开始日期', + endAt: '结束日期', + inviter: '邀请人', + invitee: '被邀请人', + user: '用户', + affCode: '邀请码', + order: '订单', + totalRebate: '累计返利', + orderAmount: '充值金额', + payAmount: '支付金额', + rebateAmount: '返利金额', + paymentType: '支付方式', + orderStatus: '订单状态', + transferAmount: '提取金额', + currentBalance: '当前余额', + remainingQuota: '剩余可提取', + frozenQuota: '冻结返利', + historyQuota: '历史返利', + invitedAt: '邀请时间', + rebatedAt: '返利时间', + transferredAt: '提取时间' + }, + overview: { + title: '用户返利概览', + affCode: '邀请码', + rebateRate: '返利比例', + invitedCount: '邀请人数', + rebatedInviteeCount: '已产生返利人数', + availableQuota: '可提余额', + historyQuota: '历史返利' + } + }, + // Users Management users: { title: '用户管理', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 06f6b212..238f6a71 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -517,6 +517,46 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.usage.description' } }, + { + path: '/admin/affiliates', + redirect: '/admin/affiliates/invites' + }, + { + path: '/admin/affiliates/invites', + name: 'AdminAffiliateInvites', + component: () => import('@/views/admin/affiliates/AdminAffiliateInvitesView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Invite Records', + titleKey: 'nav.affiliateInviteRecords', + descriptionKey: 'admin.affiliates.invitesDescription' + } + }, + { + path: '/admin/affiliates/rebates', + name: 'AdminAffiliateRebates', + component: () => import('@/views/admin/affiliates/AdminAffiliateRebatesView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Rebate Records', + titleKey: 'nav.affiliateRebateRecords', + descriptionKey: 'admin.affiliates.rebatesDescription' + } + }, + { + path: '/admin/affiliates/transfers', + name: 'AdminAffiliateTransfers', + component: () => import('@/views/admin/affiliates/AdminAffiliateTransfersView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Affiliate Transfer Records', + titleKey: 'nav.affiliateTransferRecords', + descriptionKey: 'admin.affiliates.transfersDescription' + } + }, // ==================== Payment Admin Routes ==================== diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue new file mode 100644 index 00000000..62c96ff8 --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateInvitesView.vue @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue new file mode 100644 index 00000000..1acd7b1b --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateRebatesView.vue @@ -0,0 +1,7 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue new file mode 100644 index 00000000..42b379f8 --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue @@ -0,0 +1,386 @@ + + + diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue b/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue new file mode 100644 index 00000000..5a56f179 --- /dev/null +++ b/frontend/src/views/admin/affiliates/AdminAffiliateTransfersView.vue @@ -0,0 +1,7 @@ + + + From 0a914e034cdc264120ac629ed7ed6bc9920d5b9f Mon Sep 17 00:00:00 2001 From: lyen1688 Date: Sun, 3 May 2026 16:44:59 +0800 Subject: [PATCH 5/8] fix: include matured affiliate quota in admin overview --- backend/internal/repository/affiliate_repo.go | 50 +++++++++++-------- .../repository/affiliate_repo_test.go | 15 ++++++ 2 files changed, 44 insertions(+), 21 deletions(-) create mode 100644 backend/internal/repository/affiliate_repo_test.go diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index 7ff497b6..793e1032 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -23,6 +23,34 @@ const ( var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") +const affiliateUserOverviewSQL = ` +SELECT ua.user_id, + COALESCE(u.email, ''), + COALESCE(u.username, ''), + ua.aff_code, + COALESCE(ua.aff_rebate_rate_percent, 0)::double precision, + (ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate, + ua.aff_count, + COALESCE(rebated.rebated_invitee_count, 0), + (ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0))::double precision, + ua.aff_history_quota::double precision +FROM user_affiliates ua +JOIN users u ON u.id = ua.user_id +LEFT JOIN ( + SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count + FROM user_affiliate_ledger + WHERE action = 'accrue' AND source_user_id IS NOT NULL + GROUP BY user_id +) rebated ON rebated.user_id = ua.user_id +LEFT JOIN ( + SELECT user_id, COALESCE(SUM(amount), 0)::double precision AS matured_frozen_quota + FROM user_affiliate_ledger + WHERE action = 'accrue' AND frozen_until IS NOT NULL AND frozen_until <= NOW() + GROUP BY user_id +) matured ON matured.user_id = ua.user_id +WHERE ua.user_id = $1 +LIMIT 1` + type affiliateQueryExecer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) @@ -575,27 +603,7 @@ func (r *affiliateRepository) GetAffiliateUserOverview(ctx context.Context, user return nil, service.ErrUserNotFound } client := clientFromContext(ctx, r.client) - rows, err := client.QueryContext(ctx, ` -SELECT ua.user_id, - COALESCE(u.email, ''), - COALESCE(u.username, ''), - ua.aff_code, - COALESCE(ua.aff_rebate_rate_percent, 0)::double precision, - (ua.aff_rebate_rate_percent IS NOT NULL) AS has_custom_rate, - ua.aff_count, - COALESCE(rebated.rebated_invitee_count, 0), - ua.aff_quota::double precision, - ua.aff_history_quota::double precision -FROM user_affiliates ua -JOIN users u ON u.id = ua.user_id -LEFT JOIN ( - SELECT user_id, COUNT(DISTINCT source_user_id)::integer AS rebated_invitee_count - FROM user_affiliate_ledger - WHERE action = 'accrue' AND source_user_id IS NOT NULL - GROUP BY user_id -) rebated ON rebated.user_id = ua.user_id -WHERE ua.user_id = $1 -LIMIT 1`, userID) + rows, err := client.QueryContext(ctx, affiliateUserOverviewSQL, userID) if err != nil { return nil, err } diff --git a/backend/internal/repository/affiliate_repo_test.go b/backend/internal/repository/affiliate_repo_test.go new file mode 100644 index 00000000..03999fa9 --- /dev/null +++ b/backend/internal/repository/affiliate_repo_test.go @@ -0,0 +1,15 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAffiliateUserOverviewSQLIncludesMaturedFrozenQuota(t *testing.T) { + query := strings.Join(strings.Fields(affiliateUserOverviewSQL), " ") + + require.Contains(t, query, "ua.aff_quota + COALESCE(matured.matured_frozen_quota, 0)") + require.Contains(t, query, "frozen_until <= NOW()") +} From 650ddb2e398d1b73d3fccfd820487788cc79139a Mon Sep 17 00:00:00 2001 From: lyen1688 Date: Sun, 3 May 2026 16:54:43 +0800 Subject: [PATCH 6/8] fix: make affiliate record users clickable --- .../admin/affiliates/AdminAffiliateRecordsTable.vue | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue index 42b379f8..74416a4a 100644 --- a/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue +++ b/frontend/src/views/admin/affiliates/AdminAffiliateRecordsTable.vue @@ -31,7 +31,7 @@ :id="row.inviter_id" :email="row.inviter_email" :username="row.inviter_username" - :clickable="props.type === 'invites'" + :clickable="props.type !== 'transfers'" @open="openUserOverview" /> @@ -40,12 +40,18 @@ :id="row.invitee_id" :email="row.invitee_email" :username="row.invitee_username" - :clickable="props.type === 'invites'" + :clickable="props.type !== 'transfers'" @open="openUserOverview" />