Merge pull request #2836 from siyuan-123/fix/openai-ws-compat-usage
修复 OpenAI WS 兼容性与 usage 统计
This commit is contained in:
commit
16842c2f8b
@ -4872,7 +4872,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
return
|
||||
}
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" &&
|
||||
if eventType != "response.completed" && eventType != "response.done" && eventType != "response.failed" &&
|
||||
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -2218,6 +2218,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||
require.Equal(t, 15, usage.OutputTokens)
|
||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||
|
||||
// failed 事件在部分上游路径也会携带已消耗 usage,应与 WS/passthrough 保持一致
|
||||
svc.parseSSEUsage(`{"type":"response.failed","response":{"usage":{"input_tokens":17,"output_tokens":19,"input_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||
require.Equal(t, 17, usage.InputTokens)
|
||||
require.Equal(t, 19, usage.OutputTokens)
|
||||
require.Equal(t, 6, usage.CacheReadInputTokens)
|
||||
|
||||
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||
require.Equal(t, 21, usage.InputTokens)
|
||||
require.Equal(t, 8, usage.OutputTokens)
|
||||
|
||||
@ -369,7 +369,12 @@ func openAIWSEventMayContainToolCalls(eventType string) bool {
|
||||
}
|
||||
|
||||
// openAIWSEventShouldParseUsage reports whether eventType names a terminal
// Responses event for which the usage payload should be parsed.
//
// Leading and trailing whitespace in eventType is ignored. The accepted
// terminal types are "response.completed", "response.done",
// "response.failed", "response.incomplete", "response.cancelled" and
// "response.canceled" (both spellings are recognised). Every other value,
// including the empty string, returns false.
func openAIWSEventShouldParseUsage(eventType string) bool {
	// A single switch replaces the earlier completed-only comparison, which
	// would otherwise return first and drop usage on every other terminal
	// event type.
	switch strings.TrimSpace(eventType) {
	case "response.completed",
		"response.done",
		"response.failed",
		"response.incomplete",
		"response.cancelled",
		"response.canceled":
		return true
	default:
		return false
	}
}
|
||||
|
||||
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
||||
@ -2484,6 +2489,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
imageInputSize string
|
||||
payloadBytes int
|
||||
}
|
||||
ingressSessionOriginalModel := ""
|
||||
|
||||
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
||||
next, err := sjson.SetBytes(current, path, value)
|
||||
@ -2547,12 +2553,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
|
||||
originalModel := strings.TrimSpace(values[1].String())
|
||||
modelMissing := originalModel == ""
|
||||
if originalModel == "" {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"model is required in response.create payload",
|
||||
nil,
|
||||
)
|
||||
// 入站 WS 长会话里,部分客户端只在第一轮 response.create 上声明
|
||||
// model,后续 turn 复用同一 session-level model。为避免因省略
|
||||
// model 直接断开用户连接,这里回落到上一轮已通过校验的客户端模型,
|
||||
// 并在下方写回上游 payload,保证账号模型映射/fast policy/图片权限
|
||||
// 仍按同一模型执行。
|
||||
originalModel = ingressSessionOriginalModel
|
||||
if originalModel == "" {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
"model is required in response.create payload",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
promptCacheKey := strings.TrimSpace(values[2].String())
|
||||
previousResponseID := strings.TrimSpace(values[3].String())
|
||||
@ -2572,7 +2587,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
normalized = next
|
||||
}
|
||||
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
|
||||
if upstreamModel != originalModel {
|
||||
if modelMissing || upstreamModel != originalModel {
|
||||
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
||||
if setErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||||
@ -2602,11 +2617,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
// single integration point for all WS ingress turns (first + follow-up
|
||||
// frames flow through here).
|
||||
//
|
||||
// Model fallback: parseClientPayload above rejects any frame whose
|
||||
// "model" field is missing (line ~2493-2500), so by the time we
|
||||
// reach this point upstreamModel is always derived from a non-empty
|
||||
// per-frame model. The capturedSessionModel fallback used in the
|
||||
// passthrough adapter is therefore not needed in this path.
|
||||
// Model fallback: first turn still requires model at the handler layer;
|
||||
// follow-up response.create frames may omit it and then reuse
|
||||
// ingressSessionOriginalModel. We always write a concrete upstream model
|
||||
// before evaluating policy, so whitelist / filter behavior remains stable.
|
||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||
if policyErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||
@ -2635,6 +2649,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
)
|
||||
}
|
||||
normalized = policyApplied
|
||||
ingressSessionOriginalModel = originalModel
|
||||
|
||||
return openAIWSClientPayload{
|
||||
payloadRaw: normalized,
|
||||
|
||||
@ -39,6 +39,24 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
|
||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIWSEventShouldParseUsageTerminalEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, eventType := range []string{
|
||||
"response.completed",
|
||||
"response.done",
|
||||
"response.failed",
|
||||
"response.incomplete",
|
||||
"response.cancelled",
|
||||
"response.canceled",
|
||||
} {
|
||||
require.True(t, openAIWSEventShouldParseUsage(eventType), eventType)
|
||||
require.True(t, openAIWSEventShouldParseUsage(" "+eventType+" "), eventType)
|
||||
}
|
||||
require.False(t, openAIWSEventShouldParseUsage("response.output_text.delta"))
|
||||
require.False(t, openAIWSEventShouldParseUsage(""))
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
||||
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
||||
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
|
||||
@ -164,6 +164,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
||||
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel
// drives two response.create turns over one ingress WebSocket session. The
// first frame declares "model":"client-model"; the second frame omits "model"
// entirely. The test asserts that both frames written to the upstream capture
// connection carry the mapped model "gpt-5.1" and that the second frame keeps
// its previous_response_id.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Gateway configuration: URL allowlist disabled, WS v2 enabled for both
// OAuth and API-key accounts, one connection per account, 3-second timeouts.
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
// Stub upstream connection holding two canned response.completed events;
// presumably one is served per client turn (confirm against
// openAIWSCaptureConn).
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
// API-key account whose model_mapping rewrites "client-model" to "gpt-5.1".
account := &Account{
|
||||
ID: 115,
|
||||
Name: "openai-ingress-omit-model",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"model_mapping": map[string]any{
|
||||
"client-model": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Buffered so the handler goroutine can report its single result without
// blocking.
serverErrCh := make(chan error, 1)
|
||||
// Test server: accepts the client WebSocket, reads the first frame, then
// hands the connection to ProxyResponsesWebSocketFromClient.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
// Gin test context built from a clone of the incoming request with a fixed
// User-Agent header.
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Turn 1: the client declares the model explicitly.
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"client-model","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, firstEvent, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, "resp_omit_model_1", gjson.GetBytes(firstEvent, "response.id").String())
|
||||
|
||||
// Turn 2: the frame has no "model" field and chains onto turn 1 through
// previous_response_id.
writeCtx, cancelWrite = context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","stream":false,"previous_response_id":"resp_omit_model_1"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead = context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, secondEvent, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, "resp_omit_model_2", gjson.GetBytes(secondEvent, "response.id").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
// Wait up to 5 seconds for the handler to finish; it must return no error.
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
// Both upstream writes must carry the mapped model, including the second
// one whose client frame had no model at all.
require.Len(t, captureConn.writes, 2)
|
||||
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[0]), "model").String())
|
||||
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[1]), "model").String())
|
||||
require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -441,6 +575,124 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState
// runs one response.create turn for an OAuth account configured for
// passthrough mode and then inspects the headers captured by the upstream
// dialer. It asserts that session_id equals
// isolateOpenAISessionID(0, "pcache_passthrough") (the prompt_cache_key of the
// first client frame) and that the turn-state and turn-metadata headers set on
// the client request are forwarded unchanged.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Gateway configuration: WS v2 with the mode router enabled; the default
// ingress mode is ctx-pool, so passthrough has to come from the account.
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
// Stub upstream connection with a single canned terminal event.
upstreamConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_headers","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||
// The capture dialer is injected as the passthrough dialer; no connection
// pool is configured on the service in this test.
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPassthroughDialer: captureDialer,
|
||||
}
|
||||
// OAuth account whose Extra settings select passthrough ingress mode.
account := &Account{
|
||||
ID: 453,
|
||||
Name: "openai-ingress-passthrough-headers",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
// Buffered so the handler goroutine can report its single result without
// blocking.
serverErrCh := make(chan error, 1)
|
||||
// Test server: accepts the client WebSocket, reads the first frame, then
// hands the connection to ProxyResponsesWebSocketFromClient.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
// The cloned request carries the User-Agent plus the turn-state and
// turn-metadata headers that the final assertions look for upstream.
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
||||
req.Header.Set(openAIWSTurnStateHeader, "turn-state-1")
|
||||
req.Header.Set(openAIWSTurnMetadataHeader, "turn-meta-1")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "oauth-token", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
// Single turn: the first frame carries the prompt_cache_key checked below.
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"prompt_cache_key":"pcache_passthrough"}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, readErr := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, "resp_passthrough_headers", gjson.GetBytes(event, "response.id").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
// The handler may return nil or an error; a non-nil error is tolerated only
// if its text mentions StatusNormalClosure.
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
if serverErr != nil {
|
||||
require.Contains(t, serverErr.Error(), "StatusNormalClosure")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 passthrough websocket 结束超时")
|
||||
}
|
||||
|
||||
// Headers recorded by the capture dialer on the upstream dial.
require.Equal(t, isolateOpenAISessionID(0, "pcache_passthrough"), captureDialer.lastHeaders.Get("session_id"))
|
||||
require.Equal(t, "turn-state-1", captureDialer.lastHeaders.Get(openAIWSTurnStateHeader))
|
||||
require.Equal(t, "turn-meta-1", captureDialer.lastHeaders.Get(openAIWSTurnMetadataHeader))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -727,6 +727,70 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
|
||||
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2_ResponseDoneUsageParsed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.done","response":{"id":"resp_done_usage","model":"gpt-5.1","usage":{"input_tokens":13,"output_tokens":8,"input_tokens_details":{"cached_tokens":5},"cache_creation_input_tokens":2,"output_tokens_details":{"image_tokens":4}}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 32,
|
||||
Name: "openai-ws-done",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hi"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_done_usage", result.RequestID)
|
||||
require.Equal(t, 13, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 5, result.Usage.CacheReadInputTokens)
|
||||
require.Equal(t, 2, result.Usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 4, result.Usage.ImageOutputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ type Usage struct {
|
||||
OutputTokens int
|
||||
CacheCreationInputTokens int
|
||||
CacheReadInputTokens int
|
||||
ImageOutputTokens int
|
||||
}
|
||||
|
||||
type RelayResult struct {
|
||||
@ -756,8 +757,21 @@ func parseUsageAndAccumulate(
|
||||
}
|
||||
|
||||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||
if !inputResult.Exists() {
|
||||
inputResult = gjson.GetBytes(message, "response.usage.prompt_tokens")
|
||||
}
|
||||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||
if !outputResult.Exists() {
|
||||
outputResult = gjson.GetBytes(message, "response.usage.completion_tokens")
|
||||
}
|
||||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||
if !cachedResult.Exists() {
|
||||
cachedResult = gjson.GetBytes(message, "response.usage.prompt_tokens_details.cached_tokens")
|
||||
}
|
||||
imageTokens := usageResult.Get("output_tokens_details.image_tokens").Int()
|
||||
if imageTokens == 0 {
|
||||
imageTokens = usageResult.Get("completion_tokens_details.image_tokens").Int()
|
||||
}
|
||||
|
||||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||
@ -771,14 +785,18 @@ func parseUsageAndAccumulate(
|
||||
return Usage{}
|
||||
}
|
||||
parsedUsage := Usage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
CacheCreationInputTokens: int(usageResult.Get("cache_creation_input_tokens").Int()),
|
||||
CacheReadInputTokens: cachedTokens,
|
||||
ImageOutputTokens: int(imageTokens),
|
||||
}
|
||||
|
||||
state.usage.InputTokens += parsedUsage.InputTokens
|
||||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||
state.usage.CacheCreationInputTokens += parsedUsage.CacheCreationInputTokens
|
||||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||
state.usage.ImageOutputTokens += parsedUsage.ImageOutputTokens
|
||||
return parsedUsage
|
||||
}
|
||||
|
||||
@ -840,7 +858,7 @@ func isTerminalEvent(eventType string) bool {
|
||||
|
||||
func shouldParseUsage(eventType string) bool {
|
||||
switch eventType {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@ -300,20 +300,41 @@ func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
||||
require.Equal(t, 0, state.usage.OutputTokens)
|
||||
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1},"cache_creation_input_tokens":4,"output_tokens_details":{"image_tokens":3}}}}`), "response.completed", nil)
|
||||
require.Equal(t, 2, state.usage.InputTokens)
|
||||
require.Equal(t, 1, state.usage.OutputTokens)
|
||||
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||
require.Equal(t, 4, state.usage.CacheCreationInputTokens)
|
||||
require.Equal(t, 3, state.usage.ImageOutputTokens)
|
||||
|
||||
result := &RelayResult{}
|
||||
enrichResult(result, state, 5*time.Millisecond)
|
||||
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||
require.Equal(t, state.usage.CacheCreationInputTokens, result.Usage.CacheCreationInputTokens)
|
||||
require.Equal(t, state.usage.ImageOutputTokens, result.Usage.ImageOutputTokens)
|
||||
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||
require.Equal(t, 2, state.usage.InputTokens)
|
||||
enrichResult(nil, state, 0)
|
||||
}
|
||||
|
||||
func TestParseUsageAndAccumulateAcceptsChatUsageAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &relayState{}
|
||||
got := parseUsageAndAccumulate(
|
||||
state,
|
||||
[]byte(`{"type":"response.done","response":{"usage":{"prompt_tokens":12,"completion_tokens":6,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"image_tokens":2}}}}`),
|
||||
"response.done",
|
||||
nil,
|
||||
)
|
||||
require.Equal(t, 12, got.InputTokens)
|
||||
require.Equal(t, 6, got.OutputTokens)
|
||||
require.Equal(t, 4, got.CacheReadInputTokens)
|
||||
require.Equal(t, 2, got.ImageOutputTokens)
|
||||
require.Equal(t, got, state.usage)
|
||||
}
|
||||
|
||||
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -377,6 +398,23 @@ func TestIsTokenEventCoverageBranches(t *testing.T) {
|
||||
require.True(t, isTokenEvent("response.done"))
|
||||
}
|
||||
|
||||
func TestShouldParseUsageTerminalEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, eventType := range []string{
|
||||
"response.completed",
|
||||
"response.done",
|
||||
"response.failed",
|
||||
"response.incomplete",
|
||||
"response.cancelled",
|
||||
"response.canceled",
|
||||
} {
|
||||
require.True(t, shouldParseUsage(eventType), eventType)
|
||||
}
|
||||
require.False(t, shouldParseUsage("response.output_text.delta"))
|
||||
require.False(t, shouldParseUsage(""))
|
||||
}
|
||||
|
||||
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -312,6 +312,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||
promptCacheKey := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "prompt_cache_key").String())
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
@ -338,7 +339,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
isCodexCLI = true
|
||||
}
|
||||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
||||
turnState := ""
|
||||
turnMetadata := ""
|
||||
if c != nil {
|
||||
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
|
||||
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
|
||||
}
|
||||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
@ -519,6 +526,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
OutputTokens: turn.Usage.OutputTokens,
|
||||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: turn.Usage.ImageOutputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
@ -593,6 +601,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
OutputTokens: relayResult.Usage.OutputTokens,
|
||||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: relayResult.Usage.ImageOutputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user