fix: drain OpenAI compat streams for usage

This commit is contained in:
shaw 2026-05-03 17:11:27 +08:00
parent b2bdba78dd
commit 72d5ee4cd1
13 changed files with 1141 additions and 205 deletions

View File

@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type) assert.Equal(t, "message_stop", events[1].Type)
} }
// TestResponsesEventToAnthropicEvents_ResponseDone verifies that a
// "response.done" event with status "completed" is translated into a
// message_delta (stop reason "end_turn", usage copied from the event)
// followed by a message_stop, and that finalizing the stream afterwards
// yields nothing.
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
	st := NewResponsesEventToAnthropicState()
	st.Model = "gpt-4o"

	doneEvent := &ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
		},
	}
	got := ResponsesEventToAnthropicEvents(doneEvent, st)

	// Exactly two events are expected: the delta carrying stop reason and
	// usage, then the stop marker.
	require.Len(t, got, 2)

	assert.Equal(t, "message_delta", got[0].Type)
	assert.Equal(t, "end_turn", got[0].Delta.StopReason)
	assert.Equal(t, 12, got[0].Usage.InputTokens)
	assert.Equal(t, 4, got[0].Usage.OutputTokens)

	assert.Equal(t, "message_stop", got[1].Type)

	// The terminal event was already handled, so finalize has nothing to add.
	assert.Nil(t, FinalizeResponsesAnthropicStream(st))
}
// TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete verifies that a
// "response.done" event whose response is "incomplete" with reason
// "max_output_tokens" produces a message_delta with stop reason "max_tokens"
// followed by a message_stop, and that finalizing afterwards yields nothing.
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
	st := NewResponsesEventToAnthropicState()
	st.Model = "gpt-4o"

	truncated := &ResponsesResponse{
		Status:            "incomplete",
		IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
		Usage:             &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
	}
	got := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
		Type:     "response.done",
		Response: truncated,
	}, st)

	require.Len(t, got, 2)

	assert.Equal(t, "message_delta", got[0].Type)
	assert.Equal(t, "max_tokens", got[0].Delta.StopReason)
	assert.Equal(t, "message_stop", got[1].Type)

	// Nothing further should be emitted once the terminal event was seen.
	assert.Nil(t, FinalizeResponsesAnthropicStream(st))
}
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
state := NewResponsesEventToAnthropicState() state := NewResponsesEventToAnthropicState()
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{

View File

@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
} }
// TestResponsesEventToChatChunks_ResponseDone verifies that, with
// IncludeUsage enabled, a "response.done" event with status "completed"
// yields two chunks: one with finish reason "stop" and one carrying the
// usage numbers from the event. Finalizing afterwards must return nil.
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
	st := NewResponsesEventToChatState()
	st.Model = "gpt-4o"
	st.IncludeUsage = true

	doneEvent := &ResponsesStreamEvent{
		Type: "response.done",
		Response: &ResponsesResponse{
			Status: "completed",
			Usage:  &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
		},
	}
	got := ResponsesEventToChatChunks(doneEvent, st)
	require.Len(t, got, 2)

	// First chunk: finish reason.
	require.NotNil(t, got[0].Choices[0].FinishReason)
	assert.Equal(t, "stop", *got[0].Choices[0].FinishReason)

	// Second chunk: usage.
	require.NotNil(t, got[1].Usage)
	assert.Equal(t, 13, got[1].Usage.PromptTokens)
	assert.Equal(t, 7, got[1].Usage.CompletionTokens)

	assert.Nil(t, FinalizeResponsesChatStream(st))
}
// TestResponsesEventToChatChunks_ResponseDoneIncomplete verifies that a
// "response.done" event whose response is "incomplete" with reason
// "max_output_tokens" yields a chunk with finish reason "length" followed by
// a usage chunk, and that finalizing afterwards returns nil.
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
	st := NewResponsesEventToChatState()
	st.Model = "gpt-4o"
	st.IncludeUsage = true

	truncated := &ResponsesResponse{
		Status:            "incomplete",
		IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
		Usage:             &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
	}
	got := ResponsesEventToChatChunks(&ResponsesStreamEvent{
		Type:     "response.done",
		Response: truncated,
	}, st)
	require.Len(t, got, 2)

	// First chunk: finish reason mapped from the incomplete reason.
	require.NotNil(t, got[0].Choices[0].FinishReason)
	assert.Equal(t, "length", *got[0].Choices[0].FinishReason)

	// Second chunk: usage is still reported for a truncated response.
	require.NotNil(t, got[1].Usage)
	assert.Equal(t, 13, got[1].Usage.PromptTokens)
	assert.Equal(t, 7, got[1].Usage.CompletionTokens)

	assert.Nil(t, FinalizeResponsesChatStream(st))
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState() state := NewResponsesEventToChatState()
state.Model = "gpt-4o" state.Model = "gpt-4o"

View File

@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
return resToAnthHandleReasoningDelta(evt, state) return resToAnthHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return resToAnthHandleBlockDone(state) return resToAnthHandleBlockDone(state)
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToAnthHandleCompleted(evt, state) return resToAnthHandleCompleted(evt, state)
default: default:
return nil return nil

View File

@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleReasoningDelta(evt, state) return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done": case "response.reasoning_summary_text.done":
return nil return nil
case "response.completed", "response.incomplete", "response.failed": // response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
case "response.completed", "response.done", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state) return resToChatHandleCompleted(evt, state)
default: default:
return nil return nil

View File

@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
type ResponsesStreamEvent struct { type ResponsesStreamEvent struct {
Type string `json:"type"` Type string `json:"type"`
// response.created / response.completed / response.failed / response.incomplete // response.created / response.completed / response.done / response.failed / response.incomplete
Response *ResponsesResponse `json:"response,omitempty"` Response *ResponsesResponse `json:"response,omitempty"`
// response.output_item.added / response.output_item.done // response.output_item.added / response.output_item.done

View File

@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
} }
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if ctx == nil {
return context.Background(), func() {}
}
if !stream { if !stream {
return ctx, func() {} return ctx, func() {}
} }
return context.WithoutCancel(ctx), func() {}
}
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil { if ctx == nil {
return context.Background(), func() {} return context.Background(), func() {}
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// upstreamContextTestKey is the context key type used by
// TestDetachUpstreamContextIgnoresClientCancel to store and look up a value
// on the parent context.
type upstreamContextTestKey string
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
require.Equal(t, 3, result.usage.InputTokens) require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens) require.Equal(t, 7, result.usage.OutputTokens)
} }
// TestDetachUpstreamContextIgnoresClientCancel verifies that the context
// returned by detachUpstreamContext is not cancelled when its parent is,
// while values stored on the parent remain readable through it.
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
	const key = upstreamContextTestKey("test-key")

	withValue := context.WithValue(context.Background(), key, "test-value")
	parent, cancelParent := context.WithCancel(withValue)

	upstreamCtx, release := detachUpstreamContext(parent)
	defer release()

	// Cancel the parent only after detaching.
	cancelParent()

	require.NoError(t, upstreamCtx.Err())
	require.Equal(t, "test-value", upstreamCtx.Value(key))
}

View File

@ -3,13 +3,16 @@ package service
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -18,6 +21,51 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// openAICompatFailingWriter wraps a gin.ResponseWriter and starts failing
// writes once failAfter successful writes have been forwarded. It is used by
// the tests in this file as the gin context's writer.
type openAICompatFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed to succeed
	writes    int // number of writes forwarded so far
}

// Write forwards p to the wrapped writer while the write budget lasts and
// returns a "client disconnected" error afterwards. With failAfter == 0 the
// very first write fails.
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
	if w.writes < w.failAfter {
		w.writes++
		return w.ResponseWriter.Write(p)
	}
	return 0, errors.New("write failed: client disconnected")
}
// openAICompatBlockingReadCloser is an io.ReadCloser test double that serves
// a fixed byte payload and then blocks further reads until Close is called,
// imitating an upstream that keeps the connection open after sending data.
type openAICompatBlockingReadCloser struct {
	data      []byte        // payload handed out by Read
	offset    int           // index of the next unread byte in data
	closed    chan struct{} // closed exactly once by Close
	closeOnce sync.Once     // guards the close of the closed channel
}

// newOpenAICompatBlockingReadCloser returns a reader that will serve data and
// then block until closed.
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
	r := &openAICompatBlockingReadCloser{data: data}
	r.closed = make(chan struct{})
	return r
}

// Read copies unread payload bytes into p. Once the payload is exhausted it
// blocks until Close has been called and then reports io.EOF.
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
	if pending := r.data[r.offset:]; len(pending) > 0 {
		copied := copy(p, pending)
		r.offset += copied
		return copied, nil
	}

	// Payload drained: wait for Close before signalling end of stream.
	<-r.closed
	return 0, io.EOF
}

// Close releases any Read blocked on the exhausted payload. It may be called
// more than once and always returns nil.
func (r *openAICompatBlockingReadCloser) Close() error {
	r.closeOnce.Do(func() { close(r.closed) })
	return nil
}
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
t.Parallel() t.Parallel()
@ -228,3 +276,238 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
} }
// TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage verifies that
// when every write to the client fails, ForwardAsAnthropic still reads the
// upstream stream through its response.completed event and reports the usage
// found there without returning an error.
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	// failAfter: 0 makes the very first client write fail.
	ginCtx.Writer = &openAICompatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}

	reqBody := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	sseLines := []string{
		`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"ok"}`,
		"",
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
		"",
		"data: [DONE]",
		"",
	}
	upstreamBody := strings.Join(sseLines, "\n")

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}
	upstream := &httpUpstreamRecorder{resp: upstreamResp}
	svc := &OpenAIGatewayService{httpUpstream: upstream}

	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(context.Background(), ginCtx, account, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, result)

	// Usage must match the numbers in the response.completed event.
	require.Equal(t, 9, result.Usage.InputTokens)
	require.Equal(t, 4, result.Usage.OutputTokens)
	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
}
// TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns verifies
// that a streaming ForwardAsAnthropic call returns within one second of
// receiving a response.completed event even though the upstream body never
// reaches EOF, and that the usage from that event is reported.
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	// failAfter: 0 makes the very first client write fail.
	ginCtx.Writer = &openAICompatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}

	reqBody := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// The body serves one terminal event and then blocks until closed.
	terminalEvent := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
	upstreamStream := newOpenAICompatBlockingReadCloser(terminalEvent)
	defer upstreamStream.Close()

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
		Body:       upstreamStream,
	}
	upstream := &httpUpstreamRecorder{resp: upstreamResp}
	svc := &OpenAIGatewayService{httpUpstream: upstream}

	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	type outcome struct {
		result *OpenAIForwardResult
		err    error
	}
	// Buffered so the goroutine can finish even if the test has timed out.
	outcomes := make(chan outcome, 1)
	go func() {
		res, err := svc.ForwardAsAnthropic(context.Background(), ginCtx, account, reqBody, "", "gpt-5.1")
		outcomes <- outcome{result: res, err: err}
	}()

	select {
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
	case got := <-outcomes:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		require.Equal(t, 15, got.result.Usage.InputTokens)
		require.Equal(t, 6, got.result.Usage.OutputTokens)
		require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
	}
}
// TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns verifies
// that a non-streaming ("stream":false) ForwardAsAnthropic call returns
// within one second of receiving a response.completed event even though the
// upstream body never reaches EOF. It checks the reported usage and that the
// body written to the client contains an "end_turn" stop reason.
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	reqBody := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// The body serves one terminal event and then blocks until closed.
	terminalEvent := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
	upstreamStream := newOpenAICompatBlockingReadCloser(terminalEvent)
	defer upstreamStream.Close()

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
		Body:       upstreamStream,
	}
	upstream := &httpUpstreamRecorder{resp: upstreamResp}
	svc := &OpenAIGatewayService{httpUpstream: upstream}

	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	type outcome struct {
		result *OpenAIForwardResult
		err    error
	}
	// Buffered so the goroutine can finish even if the test has timed out.
	outcomes := make(chan outcome, 1)
	go func() {
		res, err := svc.ForwardAsAnthropic(context.Background(), ginCtx, account, reqBody, "", "gpt-5.1")
		outcomes <- outcome{result: res, err: err}
	}()

	select {
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
	case got := <-outcomes:
		require.NoError(t, got.err)
		require.NotNil(t, got.result)
		require.Equal(t, 15, got.result.Usage.InputTokens)
		require.Equal(t, 6, got.result.Usage.OutputTokens)
		require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
		require.Contains(t, recorder.Body.String(), `"stop_reason":"end_turn"`)
	}
}
// TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError verifies
// that an upstream stream consisting only of the "[DONE]" sentinel makes
// ForwardAsAnthropic return a "missing terminal event" error together with a
// non-nil result whose usage counters are zero.
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	reqBody := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(reqBody))
	ginCtx.Request.Header.Set("Content-Type", "application/json")

	// Only the sentinel is sent; no response.* terminal event precedes it.
	const upstreamBody = "data: [DONE]\n\n"
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}
	upstream := &httpUpstreamRecorder{resp: upstreamResp}
	svc := &OpenAIGatewayService{httpUpstream: upstream}

	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(context.Background(), ginCtx, account, reqBody, "", "gpt-5.1")
	require.Error(t, err)
	require.Contains(t, err.Error(), "missing terminal event")

	// A result is still returned alongside the error, with no usage recorded.
	require.NotNil(t, result)
	require.Zero(t, result.Usage.InputTokens)
	require.Zero(t, result.Usage.OutputTokens)
}
// TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel verifies that
// ForwardAsAnthropic succeeds when the client's context is already cancelled
// before the call, and that the request handed to the upstream carries a
// context that is not cancelled.
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)

	clientCtx, cancelClient := context.WithCancel(context.Background())
	reqBody := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(reqBody)).WithContext(clientCtx)
	ginCtx.Request.Header.Set("Content-Type", "application/json")
	// Cancel the client context before forwarding starts.
	cancelClient()

	sseLines := []string{
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
		"",
		"data: [DONE]",
		"",
	}
	upstreamBody := strings.Join(sseLines, "\n")

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
		Body:       io.NopCloser(strings.NewReader(upstreamBody)),
	}
	upstream := &httpUpstreamRecorder{resp: upstreamResp}
	svc := &OpenAIGatewayService{httpUpstream: upstream}

	account := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	result, err := svc.ForwardAsAnthropic(clientCtx, ginCtx, account, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, result)

	// The recorded upstream request must not have inherited the cancellation.
	require.NotNil(t, upstream.lastReq)
	require.NoError(t, upstream.lastReq.Context().Err())
}

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var usage OpenAIUsage var usage OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: requestID, RequestID: requestID,
@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
} }
chunks := apicompat.ResponsesEventToChatChunks(&event, state) chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks { if !clientDisconnected {
sse, err := apicompat.ChatChunkToSSE(chunk) for _, chunk := range chunks {
if err != nil { sse, err := apicompat.ChatChunkToSSE(chunk)
logger.L().Warn("openai chat_completions stream: failed to marshal chunk", if err != nil {
zap.Error(err), logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai chat_completions stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(chunks) > 0 { if len(chunks) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks { for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk) sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
} }
} }
// Send [DONE] sentinel // Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck if !clientDisconnected {
c.Writer.Flush() if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
}
if !clientDisconnected {
c.Writer.Flush()
}
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// Determine keepalive interval // Determine keepalive interval
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
// No keepalive: fast synchronous path // No keepalive: fast synchronous path
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// With keepalive: goroutine + channel + select // With keepalive: goroutine + channel + select
@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
select { select {
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai chat_completions stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
logger.L().Info("openai chat_completions stream: client disconnected during keepalive", logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }

View File

@ -1,13 +1,36 @@
package service package service
import ( import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
// openAIChatFailingWriter wraps a gin.ResponseWriter and starts failing
// writes once failAfter successful writes have been forwarded. It is used by
// the chat-completions tests in this file as the gin context's writer.
type openAIChatFailingWriter struct {
	gin.ResponseWriter
	failAfter int // number of writes allowed to succeed
	writes    int // number of writes forwarded so far
}

// Write forwards p to the wrapped writer while the write budget lasts and
// returns a "client disconnected" error afterwards. With failAfter == 0 the
// very first write fails.
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
	if w.writes < w.failAfter {
		w.writes++
		return w.ResponseWriter.Write(p)
	}
	return 0, errors.New("write failed: client disconnected")
}
func TestNormalizeResponsesRequestServiceTier(t *testing.T) { func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
t.Parallel() t.Parallel()
@ -73,3 +96,238 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Empty(t, tier) require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists()) require.False(t, gjson.GetBytes(body, "service_tier").Exists())
} }
// TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage checks that
// the usage carried by the upstream terminal event is still reported when the
// client writer rejects every write.
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq
	// failAfter: 0 means the first write to the client already fails.
	ginCtx.Writer = &openAIChatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}

	// Upstream stream: created -> text delta -> completed (with usage) -> [DONE].
	sseLines := []string{
		`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
		"",
		`data: {"type":"response.output_text.delta","delta":"ok"}`,
		"",
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
		"",
		"data: [DONE]",
		"",
	}
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}
	gateway := &OpenAIGatewayService{httpUpstream: &httpUpstreamRecorder{resp: upstreamResp}}

	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Equal(t, 11, got.Usage.InputTokens)
	require.Equal(t, 5, got.Usage.OutputTokens)
	require.Equal(t, 4, got.Usage.CacheReadInputTokens)
}
// TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns checks
// that a streaming forward returns within one second of receiving the terminal
// event, with its usage, even though the upstream body is never closed by the
// server side.
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq
	ginCtx.Writer = &openAIChatFailingWriter{ResponseWriter: ginCtx.Writer, failAfter: 0}

	// Only the terminal event is sent; no [DONE] sentinel and no EOF follow.
	// NOTE(review): the blocking behavior of this reader is inferred from its
	// name and the failure message below — confirm against its definition.
	terminalPayload := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
	blockingBody := newOpenAICompatBlockingReadCloser(terminalPayload)
	defer blockingBody.Close()

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
		Body:       blockingBody,
	}
	gateway := &OpenAIGatewayService{httpUpstream: &httpUpstreamRecorder{resp: upstreamResp}}

	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	type forwardOutcome struct {
		res *OpenAIForwardResult
		err error
	}
	// Buffered so the goroutine can finish even if the test has timed out.
	outcomes := make(chan forwardOutcome, 1)
	go func() {
		res, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
		outcomes <- forwardOutcome{res: res, err: err}
	}()

	select {
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
	case out := <-outcomes:
		require.NoError(t, out.err)
		require.NotNil(t, out.res)
		require.Equal(t, 17, out.res.Usage.InputTokens)
		require.Equal(t, 8, out.res.Usage.OutputTokens)
		require.Equal(t, 6, out.res.Usage.CacheReadInputTokens)
	}
}
// TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns
// checks the non-streaming ("stream":false) path: the forward must return
// within one second of the terminal event, report its usage, and write a chat
// completion with finish_reason "stop" to the client.
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httpReq

	// Only the terminal event is sent; no [DONE] sentinel and no EOF follow.
	// NOTE(review): the blocking behavior of this reader is inferred from its
	// name and the failure message below — confirm against its definition.
	terminalPayload := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
	blockingBody := newOpenAICompatBlockingReadCloser(terminalPayload)
	defer blockingBody.Close()

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
		Body:       blockingBody,
	}
	gateway := &OpenAIGatewayService{httpUpstream: &httpUpstreamRecorder{resp: upstreamResp}}

	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	type forwardOutcome struct {
		res *OpenAIForwardResult
		err error
	}
	// Buffered so the goroutine can finish even if the test has timed out.
	outcomes := make(chan forwardOutcome, 1)
	go func() {
		res, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
		outcomes <- forwardOutcome{res: res, err: err}
	}()

	select {
	case <-time.After(time.Second):
		require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
	case out := <-outcomes:
		require.NoError(t, out.err)
		require.NotNil(t, out.res)
		require.Equal(t, 17, out.res.Usage.InputTokens)
		require.Equal(t, 8, out.res.Usage.OutputTokens)
		require.Equal(t, 6, out.res.Usage.CacheReadInputTokens)
		require.Contains(t, recorder.Body.String(), `"finish_reason":"stop"`)
	}
}
// TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError checks
// that an upstream stream consisting solely of the [DONE] sentinel yields a
// "missing terminal event" error together with a non-nil, zero-usage result.
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))
	httpReq.Header.Set("Content-Type", "application/json")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq

	// The upstream never emits a terminal Responses event, only the sentinel.
	const sentinelOnly = "data: [DONE]\n\n"
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
		Body:       io.NopCloser(strings.NewReader(sentinelOnly)),
	}
	gateway := &OpenAIGatewayService{httpUpstream: &httpUpstreamRecorder{resp: upstreamResp}}

	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.Error(t, err)
	require.Contains(t, err.Error(), "missing terminal event")
	require.NotNil(t, got)
	require.Zero(t, got.Usage.InputTokens)
	require.Zero(t, got.Usage.OutputTokens)
}
// TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel checks that
// the request handed to the upstream recorder carries a context that is not
// cancelled, even though the inbound client context was cancelled before the
// forward started.
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	clientCtx, cancelClient := context.WithCancel(context.Background())

	reqBody := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
	httpReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody)).WithContext(clientCtx)
	httpReq.Header.Set("Content-Type", "application/json")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq

	// Simulate a client that has already gone away.
	cancelClient()

	sseLines := []string{
		`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
		"",
		"data: [DONE]",
		"",
	}
	recorderUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}
	gateway := &OpenAIGatewayService{httpUpstream: recorderUpstream}

	oauthAccount := &Account{
		ID:          1,
		Name:        "openai-oauth",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Credentials: map[string]any{
			"access_token":       "oauth-token",
			"chatgpt_account_id": "chatgpt-acc",
		},
	}

	got, err := gateway.ForwardAsChatCompletions(clientCtx, ginCtx, oauthAccount, reqBody, "", "gpt-5.1")
	require.NoError(t, err)
	require.NotNil(t, got)
	require.NotNil(t, recorderUpstream.lastReq)
	require.NoError(t, recorderUpstream.lastReq.Context().Err())
}

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 6. Build upstream request // 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err) return nil, fmt.Errorf("build upstream request: %w", err)
} }
@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id") requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body) finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
maxLineSize := defaultMaxLineSize if err != nil {
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { return nil, err
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
acc := apicompat.NewBufferedResponseAccumulator()
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Accumulate delta content for fallback when terminal output is empty.
acc.ProcessEvent(&event)
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.done" ||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
} }
if finalResponse == nil { if finalResponse == nil {
@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
}, nil }, nil
} }
// isOpenAICompatResponsesTerminalEvent reports whether eventType, after
// trimming surrounding whitespace, is one of the Responses stream event types
// treated as terminal here: response.completed, response.done,
// response.incomplete or response.failed. Matching is case-sensitive.
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
	trimmed := strings.TrimSpace(eventType)
	return trimmed == "response.completed" ||
		trimmed == "response.done" ||
		trimmed == "response.incomplete" ||
		trimmed == "response.failed"
}
// isOpenAICompatDoneSentinelLine reports whether line is an SSE data line (as
// recognized by extractOpenAISSEDataLine) whose whitespace-trimmed payload is
// exactly "[DONE]".
func isOpenAICompatDoneSentinelLine(line string) bool {
	payload, isData := extractOpenAISSEDataLine(line)
	if !isData {
		return false
	}
	return strings.TrimSpace(payload) == "[DONE]"
}
// readOpenAICompatBufferedTerminal scans the upstream SSE body line by line and
// returns as soon as a terminal Responses event that carries a Response is seen,
// without waiting for the upstream to close the connection.
//
// Parameters:
//   - resp: upstream HTTP response whose Body is the SSE stream.
//   - logPrefix: prefix prepended to every log message emitted here.
//   - requestID: attached to log entries as "request_id".
//
// Returns, in order:
//   - the terminal event's Response, or nil when the stream ended or a [DONE]
//     sentinel arrived before any terminal event (error is nil in that case, so
//     callers must check for a nil response themselves);
//   - the usage copied from that Response (zero value if it had none);
//   - the accumulator that was fed every successfully parsed event;
//   - an error for a nil response/body, a scanner read error, or a data
//     interval timeout.
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
resp *http.Response,
logPrefix string,
requestID string,
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
acc := apicompat.NewBufferedResponseAccumulator()
var usage OpenAIUsage
if resp == nil || resp.Body == nil {
return nil, usage, acc, errors.New("upstream response body is nil")
}
scanner := bufio.NewScanner(resp.Body)
// Allow long SSE lines: configured limit if set, package default otherwise.
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// streamInterval stays 0 (timeout disabled) unless the config sets a positive
// StreamDataIntervalTimeout, which is interpreted in seconds.
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// timeoutCh remains nil when the timeout is disabled; receiving from a nil
// channel blocks forever, so that select case never fires.
var timeoutCh <-chan time.Time
var timeoutTimer *time.Timer
// resetTimeout (re)arms the inactivity timer, creating it on first use.
resetTimeout := func() {
if streamInterval <= 0 {
return
}
if timeoutTimer == nil {
timeoutTimer = time.NewTimer(streamInterval)
timeoutCh = timeoutTimer.C
return
}
// Stop reports false if the timer already fired; drain any pending tick
// (non-blocking) so a stale value is not observed after Reset.
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
timeoutTimer.Reset(streamInterval)
}
// stopTimeout stops the timer and drains a pending tick, if a timer exists.
stopTimeout := func() {
if timeoutTimer == nil {
return
}
if !timeoutTimer.Stop() {
select {
case <-timeoutTimer.C:
default:
}
}
}
resetTimeout()
defer stopTimeout()
// scanEvent carries either one scanned line or the scanner's final error.
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
// Reader goroutine: forwards scanned lines to events and closes events when
// scanning stops. It gives up sending as soon as done is closed.
// NOTE(review): while blocked inside scanner.Scan() it cannot observe done;
// it presumably exits only once resp.Body is closed by the caller — confirm.
go func() {
defer close(events)
for scanner.Scan() {
select {
case events <- scanEvent{line: scanner.Text()}:
case <-done:
return
}
}
if err := scanner.Err(); err != nil {
select {
case events <- scanEvent{err: err}:
case <-done:
}
}
}()
defer close(done)
for {
select {
case ev, ok := <-events:
if !ok {
// Stream ended cleanly without a terminal event.
return nil, usage, acc, nil
}
// Any item from the scanner (including non-data lines) counts as activity.
resetTimeout()
if ev.err != nil {
// Context cancellation/deadline errors are returned but not logged.
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
logger.L().Warn(logPrefix+": read error",
zap.Error(ev.err),
zap.String("request_id", requestID),
)
}
return nil, usage, acc, ev.err
}
payload, ok := extractOpenAISSEDataLine(ev.line)
if !ok || payload == "" {
continue
}
// [DONE] before any terminal event: report "no terminal response" to the caller.
if strings.TrimSpace(payload) == "[DONE]" {
return nil, usage, acc, nil
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
// Unparseable events are logged and skipped rather than aborting the read.
logger.L().Warn(logPrefix+": failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
acc.ProcessEvent(&event)
// Return on the first terminal event that carries a Response; a terminal
// event type with a nil Response does not end the loop.
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
if event.Response.Usage != nil {
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
}
return event.Response, usage, acc, nil
}
case <-timeoutCh:
// No scanner activity within streamInterval: close the body (which should
// also unblock the reader goroutine) and report the timeout.
_ = resp.Body.Close()
logger.L().Warn(logPrefix+": data interval timeout",
zap.String("request_id", requestID),
zap.Duration("interval", streamInterval),
)
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
}
}
}
// handleAnthropicStreamingResponse reads Responses SSE events from upstream, // handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client. // converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel // When StreamKeepaliveInterval is configured, it uses a goroutine + channel
@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
var usage OpenAIUsage var usage OpenAIUsage
var firstTokenMs *int var firstTokenMs *int
firstChunk := true firstChunk := true
clientDisconnected := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// resultWithUsage builds the final result snapshot. // resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult { resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{ return &OpenAIForwardResult{
@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// processDataLine handles a single "data: ..." SSE line from upstream. // processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool { processDataLine := func(payload string) bool {
if firstChunk { if firstChunk {
firstChunk = false firstChunk = false
@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
return false return false
} }
// Extract usage from completion events // 仅按兼容转换器支持的终止事件提取 usage避免无意扩大事件语义。
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
event.Response != nil && event.Response.Usage != nil { if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
} }
// Convert to Anthropic events // Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state) events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events { if !clientDisconnected {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) for _, evt := range events {
if err != nil { sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
logger.L().Warn("openai messages stream: failed to marshal event", if err != nil {
zap.Error(err), logger.L().Warn("openai messages stream: failed to marshal event",
zap.String("request_id", requestID), zap.Error(err),
) zap.String("request_id", requestID),
continue )
} continue
if _, err := fmt.Fprint(c.Writer, sse); err != nil { }
logger.L().Info("openai messages stream: client disconnected", if _, err := fmt.Fprint(c.Writer, sse); err != nil {
zap.String("request_id", requestID), clientDisconnected = true
) logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
return true zap.String("request_id", requestID),
)
break
}
} }
} }
if len(events) > 0 { if len(events) > 0 && !clientDisconnected {
c.Writer.Flush() c.Writer.Flush()
} }
return false return isTerminalEvent
} }
// finalizeStream sends any remaining Anthropic events and returns the result. // finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) { finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 { if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
for _, evt := range finalEvents { for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt) sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil { if err != nil {
continue continue
} }
fmt.Fprint(c.Writer, sse) //nolint:errcheck if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai messages stream: client disconnected during final flush",
zap.String("request_id", requestID),
)
break
}
}
if !clientDisconnected {
c.Writer.Flush()
} }
c.Writer.Flush()
} }
return resultWithUsage(), nil return resultWithUsage(), nil
} }
@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
) )
} }
} }
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
// ── Determine keepalive interval ── // ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0) keepaliveInterval := time.Duration(0)
@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
// ── No keepalive: fast synchronous path (no goroutine overhead) ── // ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 { if streamInterval <= 0 && keepaliveInterval <= 0 {
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
} }
handleScanErr(scanner.Err()) if err := scanner.Err(); err != nil {
return finalizeStream() handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
return missingTerminalErr()
} }
// ── With keepalive: goroutine + channel + select ── // ── With keepalive: goroutine + channel + select ──
@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
} }
events := make(chan scanEvent, 16) events := make(chan scanEvent, 16)
done := make(chan struct{}) done := make(chan struct{})
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
sendEvent := func(ev scanEvent) bool { sendEvent := func(ev scanEvent) bool {
select { select {
case events <- ev: case events <- ev:
@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
go func() { go func() {
defer close(events) defer close(events)
for scanner.Scan() { for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) { if !sendEvent(scanEvent{line: scanner.Text()}) {
return return
} }
@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
}() }()
defer close(done) defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval) var keepaliveTicker *time.Ticker
defer keepaliveTicker.Stop() if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now() lastDataAt := time.Now()
for { for {
@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
// Upstream closed // Upstream closed
return finalizeStream() return missingTerminalErr()
} }
if ev.err != nil { if ev.err != nil {
handleScanErr(ev.err) handleScanErr(ev.err)
return finalizeStream() return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
} }
lastDataAt = time.Now() lastDataAt = time.Now()
line := ev.line line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { payload, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
if processDataLine(line[6:]) { if strings.TrimSpace(payload) == "[DONE]" {
return resultWithUsage(), nil return missingTerminalErr()
}
if processDataLine(payload) {
return finalizeStream()
} }
case <-keepaliveTicker.C: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.L().Warn("openai messages stream: data interval timeout",
zap.String("request_id", requestID),
zap.String("model", originalModel),
zap.Duration("interval", streamInterval),
)
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
logger.L().Info("openai messages stream: client disconnected during keepalive", logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID), zap.String("request_id", requestID),
) )
return resultWithUsage(), nil clientDisconnected = true
continue
} }
c.Writer.Flush() c.Writer.Flush()
} }
@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
}, },
}) })
} }
// copyOpenAIUsageFromResponsesUsage converts a Responses API usage block into
// an OpenAIUsage value. A nil usage yields the zero OpenAIUsage. Input and
// output token counts are copied directly; CacheReadInputTokens is taken from
// InputTokensDetails.CachedTokens when those details are present.
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
	var converted OpenAIUsage
	if usage == nil {
		return converted
	}

	converted.InputTokens = usage.InputTokens
	converted.OutputTokens = usage.OutputTokens
	if details := usage.InputTokensDetails; details != nil {
		converted.CacheReadInputTokens = details.CachedTokens
	}
	return converted
}

View File

@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
httpInvalidEncryptedContentRetryTried := false httpInvalidEncryptedContentRetryTried := false
for { for {
// Build upstream request // Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {
@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err return nil, err
} }
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx() releaseUpstreamCtx()
if err != nil { if err != nil {

View File

@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`) require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
} }
// TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel
// checks that, for an account with openai_passthrough enabled, the request
// handed to the upstream recorder has a non-cancelled context even though the
// inbound client context was cancelled before Forward was called.
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	clientCtx, cancelClient := context.WithCancel(context.Background())

	httpReq := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(clientCtx)
	httpReq.Header.Set("User-Agent", "codex_cli_rs/0.1.0")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq

	// Simulate a client that has already gone away.
	cancelClient()

	requestBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)

	sseLines := []string{
		`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
		"",
		"data: [DONE]",
		"",
	}
	recorderUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}

	gateway := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: recorderUpstream,
	}

	passthroughAccount := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	got, err := gateway.Forward(clientCtx, ginCtx, passthroughAccount, requestBody)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.NotNil(t, recorderUpstream.lastReq)
	require.NoError(t, recorderUpstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t) logSink, restore := captureStructuredLog(t)
@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
require.Contains(t, string(upstream.lastBody), `"stream":true`) require.Contains(t, string(upstream.lastBody), `"stream":true`)
} }
// TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel
// checks that, for an account with openai_passthrough disabled and a
// non-streaming request body, the request handed to the upstream recorder has a
// non-cancelled context even though the inbound client context was cancelled
// before Forward was called.
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
	gin.SetMode(gin.TestMode)

	clientCtx, cancelClient := context.WithCancel(context.Background())

	httpReq := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(clientCtx)
	httpReq.Header.Set("User-Agent", "codex_cli_rs/0.1.0")

	ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ginCtx.Request = httpReq

	// Simulate a client that has already gone away.
	cancelClient()

	requestBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)

	sseLines := []string{
		`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
		"",
		"data: [DONE]",
		"",
	}
	recorderUpstream := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
		Body:       io.NopCloser(strings.NewReader(strings.Join(sseLines, "\n"))),
	}}

	gateway := &OpenAIGatewayService{
		cfg:          &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
		httpUpstream: recorderUpstream,
	}

	legacyAccount := &Account{
		ID:             123,
		Name:           "acc",
		Platform:       PlatformOpenAI,
		Type:           AccountTypeOAuth,
		Concurrency:    1,
		Credentials:    map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
		Extra:          map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
		Status:         StatusActive,
		Schedulable:    true,
		RateMultiplier: f64p(1),
	}

	got, err := gateway.Forward(clientCtx, ginCtx, legacyAccount, requestBody)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.NotNil(t, recorderUpstream.lastReq)
	require.NoError(t, recorderUpstream.lastReq.Context().Err())
}
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)