diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index 9535395f..3be765a2 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "go.uber.org/zap" ) @@ -97,6 +98,13 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( return nil, policyErr } upstreamBody = updatedBody + if clientStream { + var usageErr error + upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody) + if usageErr != nil { + return nil, fmt.Errorf("enable stream usage: %w", usageErr) + } + } logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion", zap.Int64("account_id", account.ID), @@ -121,7 +129,9 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( } targetURL := buildOpenAIChatCompletionsURL(validatedURL) - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } @@ -219,8 +229,8 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( // 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。 // // usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。 -// 本函数不检查客户端的请求 flag——上游会自行处理,我们仅在上游响应 -// chunk 中出现 usage 时提取。 +// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage, +// 让级联代理或下游计费系统也能拿到完整用量。 func (s *OpenAIGatewayService) streamRawChatCompletions( c *gin.Context, resp *http.Response, @@ -251,36 +261,41 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( var usage OpenAIUsage var firstTokenMs *int + clientDisconnected := false for scanner.Scan() { line := scanner.Text() - // Direct passthrough: write each line + blank line separator - if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { - logger.L().Debug("openai chat_completions raw: client write failed", - zap.Error(werr), - zap.String("request_id", requestID), - ) - break + if payload, ok := extractOpenAISSEDataLine(line); ok { + trimmedPayload := strings.TrimSpace(payload) + if trimmedPayload != "[DONE]" { + usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload) + if u := extractCCStreamUsage(payload); u != nil { + usage = *u + } + if firstTokenMs == nil && !usageOnlyChunk { + elapsed := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &elapsed + } + } + } + + if !clientDisconnected { + if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + } } if line == "" { - c.Writer.Flush() + if !clientDisconnected { + c.Writer.Flush() + } continue } - c.Writer.Flush() - - // Track first token timing on first non-empty data line - if firstTokenMs == nil && strings.HasPrefix(line, "data: ") && line != "data: [DONE]" { - elapsed := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &elapsed - } - - // Extract usage from any chunk that carries it (CC streams typically put - // usage in the final chunk before [DONE], but may also appear elsewhere). - if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" { - payload := line[6:] - if u := extractCCStreamUsage(payload); u != nil { - usage = *u - } + if !clientDisconnected { + c.Writer.Flush() } } @@ -307,6 +322,45 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( }, nil } +// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。 +// usage 也会继续向下游透传,支持级联代理和下游计费系统。 +func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) { + updated, err := sjson.SetBytes(body, "stream_options.include_usage", true) + if err != nil { + return body, err + } + return updated, nil +} + +func isOpenAIChatUsageOnlyStreamChunk(payload string) bool { + if strings.TrimSpace(payload) == "" { + return false + } + if !gjson.Get(payload, "usage").Exists() { + return false + } + choices := gjson.Get(payload, "choices") + return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0 +} + +// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。 +// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时), +// 但上游可能在多个 chunk 中重复——总是用最新值。 +func extractCCStreamUsage(payload string) *OpenAIUsage { + usageResult := gjson.Get(payload, "usage") + if !usageResult.Exists() || !usageResult.IsObject() { + return nil + } + u := OpenAIUsage{ + InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()), + OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()), + } + if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() { + u.CacheReadInputTokens = int(cached.Int()) + } + return &u +} + // bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。 func (s *OpenAIGatewayService) bufferRawChatCompletions( c *gin.Context, @@ -320,9 +374,11 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions( ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 32<<20)) + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + } return nil, fmt.Errorf("read upstream body: %w", err) } @@ -362,24 +418,6 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions( }, nil } -// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。 -// CC 协议中 usage 仅出现在末尾 chunk(且仅当客户端请求 stream_options.include_usage -// 时),但上游可能在多个 chunk 中重复——总是用最新值。 -func extractCCStreamUsage(payload string) *OpenAIUsage { - usageResult := gjson.Get(payload, "usage") - if !usageResult.Exists() || !usageResult.IsObject() { - return nil - } - u := OpenAIUsage{ - InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()), - OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()), - } - if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() { - u.CacheReadInputTokens = int(cached.Int()) - } - return &u -} - // buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。 // // - base 已是 /chat/completions:原样返回 diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go index 01013837..1be07fd7 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go @@ -3,9 +3,19 @@ package service import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestBuildOpenAIChatCompletionsURL(t *testing.T) { @@ -65,3 +75,186 @@ func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) { }) } } + +func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) + require.Contains(t, rec.Body.String(), `"usage"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 17, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 6, result.Usage.CacheReadInputTokens) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) +} + +func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) { + gin.SetMode(gin.TestMode) + + reqCtx, cancel := context.WithCancel(context.Background()) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx) + c.Request.Header.Set("Content-Type", "application/json") + cancel() + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + require.NoError(t, upstream.lastReq.Context().Err()) +} + +func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) { + t.Parallel() + + require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`)) + require.False(t, isOpenAIChatUsageOnlyStreamChunk(``)) +} + +func TestEnsureOpenAIChatStreamUsage(t *testing.T) { + t.Parallel() + + body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`)) + require.NoError(t, err) + require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool()) + + body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`)) + require.NoError(t, err) + require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool()) +} + +func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader("toolong")), + } + svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()} + svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now()) + require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge) + require.Nil(t, result) + require.Equal(t, http.StatusBadGateway, rec.Code) +} + +func rawChatCompletionsTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +func rawChatCompletionsTestAccount() *Account { + return &Account{ + ID: 101, + Name: "raw-openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "http://upstream.example", + }, + } +}