fix: drain OpenAI compat streams for usage
This commit is contained in:
parent
b2bdba78dd
commit
72d5ee4cd1
@ -434,6 +434,45 @@ func TestStreamingTextOnly(t *testing.T) {
|
|||||||
assert.Equal(t, "message_stop", events[1].Type)
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
|
||||||
|
state := NewResponsesEventToAnthropicState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
|
||||||
|
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
assert.Equal(t, "message_delta", events[0].Type)
|
||||||
|
assert.Equal(t, "end_turn", events[0].Delta.StopReason)
|
||||||
|
assert.Equal(t, 12, events[0].Usage.InputTokens)
|
||||||
|
assert.Equal(t, 4, events[0].Usage.OutputTokens)
|
||||||
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
|
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
|
||||||
|
state := NewResponsesEventToAnthropicState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
|
||||||
|
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "incomplete",
|
||||||
|
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 12, OutputTokens: 4},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, events, 2)
|
||||||
|
assert.Equal(t, "message_delta", events[0].Type)
|
||||||
|
assert.Equal(t, "max_tokens", events[0].Delta.StopReason)
|
||||||
|
assert.Equal(t, "message_stop", events[1].Type)
|
||||||
|
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
|
||||||
state := NewResponsesEventToAnthropicState()
|
state := NewResponsesEventToAnthropicState()
|
||||||
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||||
|
|||||||
@ -720,6 +720,49 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
|||||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.done",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "incomplete",
|
||||||
|
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
|
||||||
|
Usage: &ResponsesUsage{InputTokens: 13, OutputTokens: 7},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].FinishReason)
|
||||||
|
assert.Equal(t, "length", *chunks[0].Choices[0].FinishReason)
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
assert.Equal(t, 13, chunks[1].Usage.PromptTokens)
|
||||||
|
assert.Equal(t, 7, chunks[1].Usage.CompletionTokens)
|
||||||
|
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
|
||||||
state := NewResponsesEventToChatState()
|
state := NewResponsesEventToChatState()
|
||||||
state.Model = "gpt-4o"
|
state.Model = "gpt-4o"
|
||||||
|
|||||||
@ -212,7 +212,9 @@ func ResponsesEventToAnthropicEvents(
|
|||||||
return resToAnthHandleReasoningDelta(evt, state)
|
return resToAnthHandleReasoningDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.done":
|
case "response.reasoning_summary_text.done":
|
||||||
return resToAnthHandleBlockDone(state)
|
return resToAnthHandleBlockDone(state)
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||||
|
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
return resToAnthHandleCompleted(evt, state)
|
return resToAnthHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -160,7 +160,9 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
|||||||
return resToChatHandleReasoningDelta(evt, state)
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.done":
|
case "response.reasoning_summary_text.done":
|
||||||
return nil
|
return nil
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
// response.done 是 Realtime/WS 与项目透传路径使用的终止别名;
|
||||||
|
// 普通 Responses HTTP SSE 的公开终止事件仍以 response.completed 为主。
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
return resToChatHandleCompleted(evt, state)
|
return resToChatHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -314,7 +314,7 @@ type ResponsesOutputTokensDetails struct {
|
|||||||
type ResponsesStreamEvent struct {
|
type ResponsesStreamEvent struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
// response.created / response.completed / response.failed / response.incomplete
|
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||||
Response *ResponsesResponse `json:"response,omitempty"`
|
Response *ResponsesResponse `json:"response,omitempty"`
|
||||||
|
|
||||||
// response.output_item.added / response.output_item.done
|
// response.output_item.added / response.output_item.done
|
||||||
|
|||||||
@ -8174,9 +8174,16 @@ func detachedBillingContext(ctx context.Context) (context.Context, context.Cance
|
|||||||
}
|
}
|
||||||
|
|
||||||
// detachStreamUpstreamContext selects the context to use for an upstream call.
//
//   - A nil ctx yields context.Background().
//   - For non-streaming requests ctx is returned unchanged.
//   - For streaming requests the result is context.WithoutCancel(ctx): it keeps
//     ctx's values but is not cancelled when ctx is.
//
// The returned CancelFunc is a no-op in every case.
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
	noop := func() {}
	switch {
	case ctx == nil:
		return context.Background(), noop
	case stream:
		return context.WithoutCancel(ctx), noop
	default:
		return ctx, noop
	}
}
|
||||||
|
|
||||||
|
func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return context.Background(), func() {}
|
return context.Background(), func() {}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// upstreamContextTestKey is a dedicated string type used as a context key by
// TestDetachUpstreamContextIgnoresClientCancel.
type upstreamContextTestKey string
|
||||||
|
|
||||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
@ -50,3 +52,14 @@ func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testi
|
|||||||
require.Equal(t, 3, result.usage.InputTokens)
|
require.Equal(t, 3, result.usage.InputTokens)
|
||||||
require.Equal(t, 7, result.usage.OutputTokens)
|
require.Equal(t, 7, result.usage.OutputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDetachUpstreamContextIgnoresClientCancel(t *testing.T) {
|
||||||
|
parent, cancel := context.WithCancel(context.WithValue(context.Background(), upstreamContextTestKey("test-key"), "test-value"))
|
||||||
|
upstreamCtx, release := detachUpstreamContext(parent)
|
||||||
|
defer release()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.NoError(t, upstreamCtx.Err())
|
||||||
|
require.Equal(t, "test-value", upstreamCtx.Value(upstreamContextTestKey("test-key")))
|
||||||
|
}
|
||||||
|
|||||||
@ -3,13 +3,16 @@ package service
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -18,6 +21,51 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type openAICompatFailingWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *openAICompatFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAICompatBlockingReadCloser is an io.ReadCloser test double that serves
// a fixed payload and then blocks: once data is exhausted, Read does not
// return until Close has been called, at which point it reports io.EOF.
type openAICompatBlockingReadCloser struct {
	data      []byte        // payload served by Read
	offset    int           // number of payload bytes already handed out
	closed    chan struct{} // closed by Close to release blocked readers
	closeOnce sync.Once     // makes Close safe to call repeatedly
}

// newOpenAICompatBlockingReadCloser returns a reader over data whose stream
// stays open until Close is called.
func newOpenAICompatBlockingReadCloser(data []byte) *openAICompatBlockingReadCloser {
	rc := new(openAICompatBlockingReadCloser)
	rc.data = data
	rc.closed = make(chan struct{})
	return rc
}

// Read copies the next unread payload bytes into p. When no payload is left
// it waits for Close and then returns io.EOF.
func (r *openAICompatBlockingReadCloser) Read(p []byte) (int, error) {
	if pending := r.data[r.offset:]; len(pending) > 0 {
		n := copy(p, pending)
		r.offset += n
		return n, nil
	}
	<-r.closed
	return 0, io.EOF
}

// Close releases any reader blocked in Read. It always returns nil and may be
// called more than once.
func (r *openAICompatBlockingReadCloser) Close() error {
	r.closeOnce.Do(func() { close(r.closed) })
	return nil
}
|
||||||
|
|
||||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -228,3 +276,238 @@ func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateCon
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_disconnect"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 9, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer upstreamStream.Close()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsAnthropic should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer upstreamStream.Close()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_buffered_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 15, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
|
||||||
|
require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsAnthropic buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := "data: [DONE]\n\n"
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_missing_terminal"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Zero(t, result.Usage.InputTokens)
|
||||||
|
require.Zero(t, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsAnthropic_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsAnthropic(reqCtx, c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -189,7 +190,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6. Build upstream request
|
// 6. Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, promptCacheKey, false)
|
||||||
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
}
|
}
|
||||||
@ -348,59 +351,9 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
|||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai chat_completions buffered", requestID)
|
||||||
maxLineSize := defaultMaxLineSize
|
if err != nil {
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
return nil, err
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
var finalResponse *apicompat.ResponsesResponse
|
|
||||||
var usage OpenAIUsage
|
|
||||||
acc := apicompat.NewBufferedResponseAccumulator()
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := line[6:]
|
|
||||||
|
|
||||||
var event apicompat.ResponsesStreamEvent
|
|
||||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
|
||||||
logger.L().Warn("openai chat_completions buffered: failed to parse event",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accumulate delta content for fallback when terminal output is empty.
|
|
||||||
acc.ProcessEvent(&event)
|
|
||||||
|
|
||||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
|
||||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
|
||||||
event.Response != nil {
|
|
||||||
finalResponse = event.Response
|
|
||||||
if event.Response.Usage != nil {
|
|
||||||
usage = OpenAIUsage{
|
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
logger.L().Warn("openai chat_completions buffered: read error",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if finalResponse == nil {
|
if finalResponse == nil {
|
||||||
@ -459,6 +412,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
firstChunk := true
|
firstChunk := true
|
||||||
|
clientDisconnected := false
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@ -467,6 +421,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
resultWithUsage := func() *OpenAIForwardResult {
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
return &OpenAIForwardResult{
|
return &OpenAIForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
@ -496,54 +464,66 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract usage from completion events
|
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||||
event.Response != nil && event.Response.Usage != nil {
|
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||||
usage = OpenAIUsage{
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
|
||||||
for _, chunk := range chunks {
|
if !clientDisconnected {
|
||||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
for _, chunk := range chunks {
|
||||||
if err != nil {
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
if err != nil {
|
||||||
zap.Error(err),
|
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
|
||||||
zap.String("request_id", requestID),
|
zap.Error(err),
|
||||||
)
|
zap.String("request_id", requestID),
|
||||||
continue
|
)
|
||||||
}
|
continue
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
}
|
||||||
logger.L().Info("openai chat_completions stream: client disconnected",
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
zap.String("request_id", requestID),
|
clientDisconnected = true
|
||||||
)
|
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
|
||||||
return true
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(chunks) > 0 {
|
if len(chunks) > 0 && !clientDisconnected {
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
return false
|
return isTerminalEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
|
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
|
||||||
for _, chunk := range finalChunks {
|
for _, chunk := range finalChunks {
|
||||||
sse, err := apicompat.ChatChunkToSSE(chunk)
|
sse, err := apicompat.ChatChunkToSSE(chunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Send [DONE] sentinel
|
// Send [DONE] sentinel
|
||||||
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
|
if !clientDisconnected {
|
||||||
c.Writer.Flush()
|
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clientDisconnected {
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -555,6 +535,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
// Determine keepalive interval
|
// Determine keepalive interval
|
||||||
keepaliveInterval := time.Duration(0)
|
keepaliveInterval := time.Duration(0)
|
||||||
@ -563,18 +546,25 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No keepalive: fast synchronous path
|
// No keepalive: fast synchronous path
|
||||||
if keepaliveInterval <= 0 {
|
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
handleScanErr(scanner.Err())
|
if err := scanner.Err(); err != nil {
|
||||||
return finalizeStream()
|
handleScanErr(err)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
|
}
|
||||||
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// With keepalive: goroutine + channel + select
|
// With keepalive: goroutine + channel + select
|
||||||
@ -584,6 +574,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}
|
}
|
||||||
events := make(chan scanEvent, 16)
|
events := make(chan scanEvent, 16)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
var lastReadAt int64
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
sendEvent := func(ev scanEvent) bool {
|
sendEvent := func(ev scanEvent) bool {
|
||||||
select {
|
select {
|
||||||
case events <- ev:
|
case events <- ev:
|
||||||
@ -595,6 +587,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(events)
|
defer close(events)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -605,30 +598,59 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
}()
|
}()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
var keepaliveTicker *time.Ticker
|
||||||
defer keepaliveTicker.Stop()
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return finalizeStream()
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
handleScanErr(ev.err)
|
handleScanErr(ev.err)
|
||||||
return finalizeStream()
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||||
}
|
}
|
||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-keepaliveTicker.C:
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||||
|
}
|
||||||
|
logger.L().Warn("openai chat_completions stream: data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.String("model", originalModel),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -637,7 +659,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
|
||||||
zap.String("request_id", requestID),
|
zap.String("request_id", requestID),
|
||||||
)
|
)
|
||||||
return resultWithUsage(), nil
|
clientDisconnected = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,13 +1,36 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type openAIChatFailingWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *openAIChatFailingWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed: client disconnected")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -73,3 +96,238 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
|||||||
require.Empty(t, tier)
|
require.Empty(t, tier)
|
||||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.output_text.delta","delta":"ok"}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":11,"output_tokens":5,"total_tokens":16,"input_tokens_details":{"cached_tokens":4}}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_disconnect"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, 11, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer upstreamStream.Close()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsChatCompletions should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := []byte(`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}` + "\n\n")
|
||||||
|
upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
|
||||||
|
defer upstreamStream.Close()
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_terminal_no_close"}},
|
||||||
|
Body: upstreamStream,
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardResult struct {
|
||||||
|
result *OpenAIForwardResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan forwardResult, 1)
|
||||||
|
go func() {
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
resultCh <- forwardResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-resultCh:
|
||||||
|
require.NoError(t, got.err)
|
||||||
|
require.NotNil(t, got.result)
|
||||||
|
require.Equal(t, 17, got.result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, got.result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
|
||||||
|
require.Contains(t, rec.Body.String(), `"finish_reason":"stop"`)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
require.Fail(t, "ForwardAsChatCompletions buffered response should return after terminal usage event even if upstream keeps the connection open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
upstreamBody := "data: [DONE]\n\n"
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_missing_terminal"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Zero(t, result.Usage.InputTokens)
|
||||||
|
require.Zero(t, result.Usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAsChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
upstreamBody := strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
"chatgpt_account_id": "chatgpt-acc",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardAsChatCompletions(reqCtx, c, account, body, "", "gpt-5.1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
@ -163,7 +164,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6. Build upstream request
|
// 6. Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, isStream, promptCacheKey, false)
|
||||||
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||||
}
|
}
|
||||||
@ -296,61 +299,9 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
finalResponse, usage, acc, err := s.readOpenAICompatBufferedTerminal(resp, "openai messages buffered", requestID)
|
||||||
maxLineSize := defaultMaxLineSize
|
if err != nil {
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
return nil, err
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
var finalResponse *apicompat.ResponsesResponse
|
|
||||||
var usage OpenAIUsage
|
|
||||||
acc := apicompat.NewBufferedResponseAccumulator()
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
|
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := line[6:]
|
|
||||||
|
|
||||||
var event apicompat.ResponsesStreamEvent
|
|
||||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
|
||||||
logger.L().Warn("openai messages buffered: failed to parse event",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accumulate delta content for fallback when terminal output is empty.
|
|
||||||
acc.ProcessEvent(&event)
|
|
||||||
|
|
||||||
// Terminal events carry the complete ResponsesResponse with output + usage.
|
|
||||||
if (event.Type == "response.completed" || event.Type == "response.done" ||
|
|
||||||
event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
|
||||||
event.Response != nil {
|
|
||||||
finalResponse = event.Response
|
|
||||||
if event.Response.Usage != nil {
|
|
||||||
usage = OpenAIUsage{
|
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
logger.L().Warn("openai messages buffered: read error",
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("request_id", requestID),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if finalResponse == nil {
|
if finalResponse == nil {
|
||||||
@ -380,6 +331,153 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAICompatResponsesTerminalEvent(eventType string) bool {
|
||||||
|
switch strings.TrimSpace(eventType) {
|
||||||
|
case "response.completed", "response.done", "response.incomplete", "response.failed":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAICompatDoneSentinelLine(line string) bool {
|
||||||
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
|
||||||
|
resp *http.Response,
|
||||||
|
logPrefix string,
|
||||||
|
requestID string,
|
||||||
|
) (*apicompat.ResponsesResponse, OpenAIUsage, *apicompat.BufferedResponseAccumulator, error) {
|
||||||
|
acc := apicompat.NewBufferedResponseAccumulator()
|
||||||
|
var usage OpenAIUsage
|
||||||
|
if resp == nil || resp.Body == nil {
|
||||||
|
return nil, usage, acc, errors.New("upstream response body is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var timeoutCh <-chan time.Time
|
||||||
|
var timeoutTimer *time.Timer
|
||||||
|
resetTimeout := func() {
|
||||||
|
if streamInterval <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if timeoutTimer == nil {
|
||||||
|
timeoutTimer = time.NewTimer(streamInterval)
|
||||||
|
timeoutCh = timeoutTimer.C
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !timeoutTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timeoutTimer.Reset(streamInterval)
|
||||||
|
}
|
||||||
|
stopTimeout := func() {
|
||||||
|
if timeoutTimer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !timeoutTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resetTimeout()
|
||||||
|
defer stopTimeout()
|
||||||
|
|
||||||
|
type scanEvent struct {
|
||||||
|
line string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan scanEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(events)
|
||||||
|
for scanner.Scan() {
|
||||||
|
select {
|
||||||
|
case events <- scanEvent{line: scanner.Text()}:
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
select {
|
||||||
|
case events <- scanEvent{err: err}:
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
return nil, usage, acc, nil
|
||||||
|
}
|
||||||
|
resetTimeout()
|
||||||
|
if ev.err != nil {
|
||||||
|
if !errors.Is(ev.err, context.Canceled) && !errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
logger.L().Warn(logPrefix+": read error",
|
||||||
|
zap.Error(ev.err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil, usage, acc, ev.err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, ok := extractOpenAISSEDataLine(ev.line)
|
||||||
|
if !ok || payload == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
|
return nil, usage, acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var event apicompat.ResponsesStreamEvent
|
||||||
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
|
logger.L().Warn(logPrefix+": failed to parse event",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
acc.ProcessEvent(&event)
|
||||||
|
|
||||||
|
if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
|
||||||
|
if event.Response.Usage != nil {
|
||||||
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
|
}
|
||||||
|
return event.Response, usage, acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-timeoutCh:
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
logger.L().Warn(logPrefix+": data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return nil, usage, acc, fmt.Errorf("stream data interval timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
|
||||||
// converts each to Anthropic SSE events, and writes them to the client.
|
// converts each to Anthropic SSE events, and writes them to the client.
|
||||||
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
|
||||||
@ -409,6 +507,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
var usage OpenAIUsage
|
var usage OpenAIUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
firstChunk := true
|
firstChunk := true
|
||||||
|
clientDisconnected := false
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@ -417,6 +516,20 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
// resultWithUsage builds the final result snapshot.
|
// resultWithUsage builds the final result snapshot.
|
||||||
resultWithUsage := func() *OpenAIForwardResult {
|
resultWithUsage := func() *OpenAIForwardResult {
|
||||||
return &OpenAIForwardResult{
|
return &OpenAIForwardResult{
|
||||||
@ -432,7 +545,6 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processDataLine handles a single "data: ..." SSE line from upstream.
|
// processDataLine handles a single "data: ..." SSE line from upstream.
|
||||||
// Returns (clientDisconnected bool).
|
|
||||||
processDataLine := func(payload string) bool {
|
processDataLine := func(payload string) bool {
|
||||||
if firstChunk {
|
if firstChunk {
|
||||||
firstChunk = false
|
firstChunk = false
|
||||||
@ -449,53 +561,58 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract usage from completion events
|
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
|
||||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
|
||||||
event.Response != nil && event.Response.Usage != nil {
|
if isTerminalEvent && event.Response != nil && event.Response.Usage != nil {
|
||||||
usage = OpenAIUsage{
|
usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
|
||||||
InputTokens: event.Response.Usage.InputTokens,
|
|
||||||
OutputTokens: event.Response.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if event.Response.Usage.InputTokensDetails != nil {
|
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to Anthropic events
|
// Convert to Anthropic events
|
||||||
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
|
||||||
for _, evt := range events {
|
if !clientDisconnected {
|
||||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
for _, evt := range events {
|
||||||
if err != nil {
|
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||||
logger.L().Warn("openai messages stream: failed to marshal event",
|
if err != nil {
|
||||||
zap.Error(err),
|
logger.L().Warn("openai messages stream: failed to marshal event",
|
||||||
zap.String("request_id", requestID),
|
zap.Error(err),
|
||||||
)
|
zap.String("request_id", requestID),
|
||||||
continue
|
)
|
||||||
}
|
continue
|
||||||
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
}
|
||||||
logger.L().Info("openai messages stream: client disconnected",
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
zap.String("request_id", requestID),
|
clientDisconnected = true
|
||||||
)
|
logger.L().Info("openai messages stream: client disconnected, continuing to drain upstream for billing",
|
||||||
return true
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(events) > 0 {
|
if len(events) > 0 && !clientDisconnected {
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
return false
|
return isTerminalEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// finalizeStream sends any remaining Anthropic events and returns the result.
|
// finalizeStream sends any remaining Anthropic events and returns the result.
|
||||||
finalizeStream := func() (*OpenAIForwardResult, error) {
|
finalizeStream := func() (*OpenAIForwardResult, error) {
|
||||||
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
|
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 && !clientDisconnected {
|
||||||
for _, evt := range finalEvents {
|
for _, evt := range finalEvents {
|
||||||
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Fprint(c.Writer, sse) //nolint:errcheck
|
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.L().Info("openai messages stream: client disconnected during final flush",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !clientDisconnected {
|
||||||
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
|
||||||
}
|
}
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
}
|
}
|
||||||
@ -509,6 +626,9 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
missingTerminalErr := func() (*OpenAIForwardResult, error) {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||||
|
}
|
||||||
|
|
||||||
// ── Determine keepalive interval ──
|
// ── Determine keepalive interval ──
|
||||||
keepaliveInterval := time.Duration(0)
|
keepaliveInterval := time.Duration(0)
|
||||||
@ -517,18 +637,25 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
|
||||||
if keepaliveInterval <= 0 {
|
if streamInterval <= 0 && keepaliveInterval <= 0 {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
handleScanErr(scanner.Err())
|
if err := scanner.Err(); err != nil {
|
||||||
return finalizeStream()
|
handleScanErr(err)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
|
||||||
|
}
|
||||||
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── With keepalive: goroutine + channel + select ──
|
// ── With keepalive: goroutine + channel + select ──
|
||||||
@ -538,6 +665,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}
|
}
|
||||||
events := make(chan scanEvent, 16)
|
events := make(chan scanEvent, 16)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
var lastReadAt int64
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
sendEvent := func(ev scanEvent) bool {
|
sendEvent := func(ev scanEvent) bool {
|
||||||
select {
|
select {
|
||||||
case events <- ev:
|
case events <- ev:
|
||||||
@ -549,6 +678,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(events)
|
defer close(events)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -559,8 +689,15 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
}()
|
}()
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
|
||||||
keepaliveTicker := time.NewTicker(keepaliveInterval)
|
var keepaliveTicker *time.Ticker
|
||||||
defer keepaliveTicker.Stop()
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -568,22 +705,44 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
if !ok {
|
if !ok {
|
||||||
// Upstream closed
|
// Upstream closed
|
||||||
return finalizeStream()
|
return missingTerminalErr()
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
handleScanErr(ev.err)
|
handleScanErr(ev.err)
|
||||||
return finalizeStream()
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||||
}
|
}
|
||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
line := ev.line
|
line := ev.line
|
||||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
payload, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if processDataLine(line[6:]) {
|
if strings.TrimSpace(payload) == "[DONE]" {
|
||||||
return resultWithUsage(), nil
|
return missingTerminalErr()
|
||||||
|
}
|
||||||
|
if processDataLine(payload) {
|
||||||
|
return finalizeStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-keepaliveTicker.C:
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||||
|
}
|
||||||
|
logger.L().Warn("openai messages stream: data interval timeout",
|
||||||
|
zap.String("request_id", requestID),
|
||||||
|
zap.String("model", originalModel),
|
||||||
|
zap.Duration("interval", streamInterval),
|
||||||
|
)
|
||||||
|
return resultWithUsage(), fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -593,7 +752,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
logger.L().Info("openai messages stream: client disconnected during keepalive",
|
||||||
zap.String("request_id", requestID),
|
zap.String("request_id", requestID),
|
||||||
)
|
)
|
||||||
return resultWithUsage(), nil
|
clientDisconnected = true
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
@ -610,3 +770,17 @@ func writeAnthropicError(c *gin.Context, statusCode int, errType, message string
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyOpenAIUsageFromResponsesUsage(usage *apicompat.ResponsesUsage) OpenAIUsage {
|
||||||
|
if usage == nil {
|
||||||
|
return OpenAIUsage{}
|
||||||
|
}
|
||||||
|
result := OpenAIUsage{
|
||||||
|
InputTokens: usage.InputTokens,
|
||||||
|
OutputTokens: usage.OutputTokens,
|
||||||
|
}
|
||||||
|
if usage.InputTokensDetails != nil {
|
||||||
|
result.CacheReadInputTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@ -2601,7 +2601,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
httpInvalidEncryptedContentRetryTried := false
|
httpInvalidEncryptedContentRetryTried := false
|
||||||
for {
|
for {
|
||||||
// Build upstream request
|
// Build upstream request
|
||||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||||
releaseUpstreamCtx()
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -2852,7 +2852,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||||
releaseUpstreamCtx()
|
releaseUpstreamCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -307,6 +307,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_CompactUsesJSONAndKeepsNonStreami
|
|||||||
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
require.Contains(t, rec.Body.String(), `"id":"cmp_123"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_passthrough_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": true, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
logSink, restore := captureStructuredLog(t)
|
logSink, restore := captureStructuredLog(t)
|
||||||
@ -405,6 +451,52 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
|||||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthLegacy_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)).WithContext(reqCtx)
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_legacy_ctx"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": false, "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(reqCtx, c, account, originalBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.NoError(t, upstream.lastReq.Context().Err())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user