From 72d5ee4cd1d57d2648dcc3e23218834505cbc402 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 3 May 2026 17:11:27 +0800 Subject: [PATCH] fix: drain OpenAI compat streams for usage --- .../pkg/apicompat/anthropic_responses_test.go | 39 ++ .../chatcompletions_responses_test.go | 43 ++ .../pkg/apicompat/responses_to_anthropic.go | 4 +- .../apicompat/responses_to_chatcompletions.go | 4 +- backend/internal/pkg/apicompat/types.go | 2 +- backend/internal/service/gateway_service.go | 7 + .../service/gateway_service_streaming_test.go | 13 + .../service/openai_compat_model_test.go | 283 +++++++++++++ .../openai_gateway_chat_completions.go | 221 +++++----- .../openai_gateway_chat_completions_test.go | 258 ++++++++++++ .../service/openai_gateway_messages.go | 376 +++++++++++++----- .../service/openai_gateway_service.go | 4 +- .../service/openai_oauth_passthrough_test.go | 92 +++++ 13 files changed, 1141 insertions(+), 205 deletions(-) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index e8b25c2b..edde85d3 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "end_turn", events[0].Delta.StopReason) + assert.Equal(t, 12, events[0].Usage.InputTokens) + assert.Equal(t, 4, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + +func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4}, + }, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, "max_tokens", events[0].Delta.StopReason) + assert.Equal(t, "message_stop", events[1].Type) + assert.Nil(t, FinalizeResponsesAnthropicStream(state)) +} + func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { state := NewResponsesEventToAnthropicState() ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 35d42999..bf5c23d5 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) { assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) } +func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.done", + Response: &ResponsesResponse{ + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7}, + }, + }, state) + require.Len(t, chunks, 2) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 13, chunks[1].Usage.PromptTokens) + assert.Equal(t, 7, chunks[1].Usage.CompletionTokens) + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { state := NewResponsesEventToChatState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 489ed238..b76f384d 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents( return resToAnthHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return resToAnthHandleBlockDone(state) - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToAnthHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 61b3bf9c..2386771d 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent return resToChatHandleReasoningDelta(evt, state) case "response.reasoning_summary_text.done": return nil - case "response.completed", "response.incomplete", "response.failed": + // response.done 是 Realtime/WS 与项目透传路径使用的终止别名; + // 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。 + case "response.completed", "response.done", "response.incomplete", "response.failed": return resToChatHandleCompleted(evt, state) default: return nil diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index f8c6b75f..0ff2cf49 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct { type ResponsesStreamEvent struct { Type string `json:"type"` - // response.created / response.completed / response.failed / response.incomplete + // response.created / response.completed / response.done / response.failed / response.incomplete Response *ResponsesResponse `json:"response,omitempty"` // response.output_item.added / response.output_item.done diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 074013c3..67d19720 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance } func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } if !stream { return ctx, func() {} } + return context.WithoutCancel(ctx), func() {} +} + +func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { return context.Background(), func() {} } diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go index c8803d39..39a7d3b0 100644 --- a/backend/internal/service/gateway_service_streaming_test.go +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +type upstreamContextTestKey string + func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 7, result.usage.OutputTokens) } + +func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) { + parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value")) + upstreamCtx, release := detachUpstreamContext(parent) + defer release() + + cancel() + + require.NoError(t, upstreamCtx.Err()) + require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key"))) +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 4396c15f..1129bf04 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -3,13 +3,16 @@ package service import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -18,6 +21,51 @@ import ( "github.com/tidwall/gjson" ) +type openAICompatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAICompatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +type openAICompatBlockingReadCloser struct { + data []byte + offset int + closed chan struct{} + closeOnce sync.Once +} + +func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser { + return &openAICompatBlockingReadCloser{ + data: data, + closed: make(chan struct{}), + } +} + +func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) { + if r.offset < len(r.data) { + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil + } + <-r.closed + return 0, io.EOF +} + +func (r *openAICompatBlockingReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.closed) + }) + return nil +} + func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { t.Parallel() @@ -228,3 +276,238 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon require.NotNil(t, result) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) } + +func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 15, got.result.Usage.InputTokens) + require.Equal(t, 6, got.result.Usage.OutputTokens) + require.Equal(t, 5, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 5822ae4c..3456cce0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai chat_completions buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai chat_completions buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ RequestID: requestID, @@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } chunks := apicompat.ResponsesEventToChatChunks(&event, state) - for _, chunk := range chunks { - sse, err := apicompat.ChatChunkToSSE(chunk) - if err != nil { - logger.L().Warn("openai chat_completions stream: failed to marshal chunk", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai chat_completions stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(chunks) > 0 { + if len(chunks) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } finalizeStream := func() (*OpenAIForwardResult, error) { - if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } } } // Send [DONE] sentinel - fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck - c.Writer.Flush() + if !clientDisconnected { + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during done flush", + zap.String("request_id", requestID), + ) + } + } + if !clientDisconnected { + c.Writer.Flush() + } return resultWithUsage(), nil } @@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // Determine keepalive interval keepaliveInterval := time.Duration(0) @@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } // No keepalive: fast synchronous path - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // With keepalive: goroutine + channel + select @@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { select { case ev, ok := <-events: if !ok { - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai chat_completions stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index 6846e03a..1236fb2c 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -1,13 +1,36 @@ package service import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) +type openAIChatFailingWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *openAIChatFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestNormalizeResponsesRequestServiceTier(t *testing.T) { t.Parallel() @@ -73,3 +96,238 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } + +func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) +} + +func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n") + upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody) + defer upstreamStream.Close() + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}}, + Body: upstreamStream, + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + type forwardResult struct { + result *OpenAIForwardResult + err error + } + resultCh := make(chan forwardResult, 1) + go func() { + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + resultCh <- forwardResult{result: result, err: err} + }() + + select { + case got := <-resultCh: + require.NoError(t, got.err) + require.NotNil(t, got.result) + require.Equal(t, 17, got.result.Usage.InputTokens) + require.Equal(t, 8, got.result.Usage.OutputTokens) + require.Equal(t, 6, got.result.Usage.CacheReadInputTokens) + require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`) + case <-time.After(time.Second): + require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open") + } +} + +func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := "data: [DONE]\n\n" + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) + require.Zero(t, result.Usage.InputTokens) + require.Zero(t, result.Usage.OutputTokens) +} + +func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 4e0ebb2e..9fd6f04c 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" @@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 6. Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) - - var finalResponse *apicompat.ResponsesResponse - var usage OpenAIUsage - acc := apicompat.NewBufferedResponseAccumulator() - - for scanner.Scan() { - line := scanner.Text() - - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { - continue - } - payload := line[6:] - - var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(payload), &event); err != nil { - logger.L().Warn("openai messages buffered: failed to parse event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - - // Accumulate delta content for fallback when terminal output is empty. - acc.ProcessEvent(&event) - - // Terminal events carry the complete ResponsesResponse with output + usage. - if (event.Type == "response.completed" || event.Type == "response.done" || - event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil { - finalResponse = event.Response - if event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } - } - } - } - - if err := scanner.Err(); err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.L().Warn("openai messages buffered: read error", - zap.Error(err), - zap.String("request_id", requestID), - ) - } + finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID) + if err != nil { + return nil, err } if finalResponse == nil { @@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( }, nil } +func isOpenAICompatResponsesTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.incomplete", "response.failed": + return true + default: + return false + } +} + +func isOpenAICompatDoneSentinelLine(line string) bool { + payload, ok := extractOpenAISSEDataLine(line) + return ok && strings.TrimSpace(payload) == "[DONE]" +} + +func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( + resp *http.Response, + logPrefix string, + requestID string, +) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) { + acc := apicompat.NewBufferedResponseAccumulator() + var usage OpenAIUsage + if resp == nil || resp.Body == nil { + return nil, usage, acc, errors.New("upstream response body is nil") + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var timeoutCh <-chan time.Time + var timeoutTimer *time.Timer + resetTimeout := func() { + if streamInterval <= 0 { + return + } + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(streamInterval) + timeoutCh = timeoutTimer.C + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + timeoutTimer.Reset(streamInterval) + } + stopTimeout := func() { + if timeoutTimer == nil { + return + } + if !timeoutTimer.Stop() { + select { + case <-timeoutTimer.C: + default: + } + } + } + resetTimeout() + defer stopTimeout() + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + go func() { + defer close(events) + for scanner.Scan() { + select { + case events <- scanEvent{line: scanner.Text()}: + case <-done: + return + } + } + if err := scanner.Err(); err != nil { + select { + case events <- scanEvent{err: err}: + case <-done: + } + } + }() + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return nil, usage, acc, nil + } + resetTimeout() + if ev.err != nil { + if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) { + logger.L().Warn(logPrefix+": read error", + zap.Error(ev.err), + zap.String("request_id", requestID), + ) + } + return nil, usage, acc, ev.err + } + + payload, ok := extractOpenAISSEDataLine(ev.line) + if !ok || payload == "" { + continue + } + if strings.TrimSpace(payload) == "[DONE]" { + return nil, usage, acc, nil + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn(logPrefix+": failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + acc.ProcessEvent(&event) + + if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } + return event.Response, usage, acc, nil + } + + case <-timeoutCh: + _ = resp.Body.Close() + logger.L().Warn(logPrefix+": data interval timeout", + zap.String("request_id", requestID), + zap.Duration("interval", streamInterval), + ) + return nil, usage, acc, fmt.Errorf("stream data interval timeout") + } + } +} + // handleAnthropicStreamingResponse reads Responses SSE events from upstream, // converts each to Anthropic SSE events, and writes them to the client. // When StreamKeepaliveInterval is configured, it uses a goroutine + channel @@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( var usage OpenAIUsage var firstTokenMs *int firstChunk := true + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ @@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // processDataLine handles a single "data: ..." SSE line from upstream. - // Returns (clientDisconnected bool). processDataLine := func(payload string) bool { if firstChunk { firstChunk = false @@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( return false } - // Extract usage from completion events - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && - event.Response != nil && event.Response.Usage != nil { - usage = OpenAIUsage{ - InputTokens: event.Response.Usage.InputTokens, - OutputTokens: event.Response.Usage.OutputTokens, - } - if event.Response.Usage.InputTokensDetails != nil { - usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens - } + // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 + isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) + if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } // Convert to Anthropic events events := apicompat.ResponsesEventToAnthropicEvents(&event, state) - for _, evt := range events { - sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) - if err != nil { - logger.L().Warn("openai messages stream: failed to marshal event", - zap.Error(err), - zap.String("request_id", requestID), - ) - continue - } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { - logger.L().Info("openai messages stream: client disconnected", - zap.String("request_id", requestID), - ) - return true + if !clientDisconnected { + for _, evt := range events { + sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + logger.L().Warn("openai messages stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing", + zap.String("request_id", requestID), + ) + break + } } } - if len(events) > 0 { + if len(events) > 0 && !clientDisconnected { c.Writer.Flush() } - return false + return isTerminalEvent } // finalizeStream sends any remaining Anthropic events and returns the result. finalizeStream := func() (*OpenAIForwardResult, error) { - if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { + if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected { for _, evt := range finalEvents { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Info("openai messages stream: client disconnected during final flush", + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() } - c.Writer.Flush() } return resultWithUsage(), nil } @@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( ) } } + missingTerminalErr := func() (*OpenAIForwardResult, error) { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } // ── Determine keepalive interval ── keepaliveInterval := time.Duration(0) @@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } // ── No keepalive: fast synchronous path (no goroutine overhead) ── - if keepaliveInterval <= 0 { + if streamInterval <= 0 && keepaliveInterval <= 0 { for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } } - handleScanErr(scanner.Err()) - return finalizeStream() + if err := scanner.Err(); err != nil { + handleScanErr(err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) + } + return missingTerminalErr() } // ── With keepalive: goroutine + channel + select ── @@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( } events := make(chan scanEvent, 16) done := make(chan struct{}) + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) sendEvent := func(ev scanEvent) bool { select { case events <- ev: @@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( }() defer close(done) - keepaliveTicker := time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } lastDataAt := time.Now() for { @@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( case ev, ok := <-events: if !ok { // Upstream closed - return finalizeStream() + return missingTerminalErr() } if ev.err != nil { handleScanErr(ev.err) - return finalizeStream() + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err) } lastDataAt = time.Now() line := ev.line - if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + payload, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - if processDataLine(line[6:]) { - return resultWithUsage(), nil + if strings.TrimSpace(payload) == "[DONE]" { + return missingTerminalErr() + } + if processDataLine(payload) { + return finalizeStream() } - case <-keepaliveTicker.C: + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") + } + logger.L().Warn("openai messages stream: data interval timeout", + zap.String("request_id", requestID), + zap.String("model", originalModel), + zap.Duration("interval", streamInterval), + ) + return resultWithUsage(), fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } @@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( logger.L().Info("openai messages stream: client disconnected during keepalive", zap.String("request_id", requestID), ) - return resultWithUsage(), nil + clientDisconnected = true + continue } c.Writer.Flush() } @@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string }, }) } + +func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage { + if usage == nil { + return OpenAIUsage{} + } + result := OpenAIUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + } + if usage.InputTokensDetails != nil { + result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens + } + return result +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ed69730c..d1d73586 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco httpInvalidEncryptedContentRetryTried := false for { // Build upstream request - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) releaseUpstreamCtx() if err != nil { @@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) releaseUpstreamCtx() if err != nil { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 049ffdd8..87a05b14 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) } +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te require.Contains(t, string(upstream.lastBody), `"stream":true`) } +func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + reqCtx, cancel := context.WithCancel(context.Background()) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + cancel() + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(reqCtx, c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { gin.SetMode(gin.TestMode)