diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 623a64e9..6d618ee4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4872,7 +4872,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag return } eventType := gjson.GetBytes(data, "type").String() - if eventType != "response.completed" && eventType != "response.done" && + if eventType != "response.completed" && eventType != "response.done" && eventType != "response.failed" && eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" { return } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 8bed920d..8aad2fa6 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -2218,6 +2218,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 15, usage.OutputTokens) require.Equal(t, 4, usage.CacheReadInputTokens) + // failed 事件在部分上游路径也会携带已消耗 usage,应与 WS/passthrough 保持一致 + svc.parseSSEUsage(`{"type":"response.failed","response":{"usage":{"input_tokens":17,"output_tokens":19,"input_tokens_details":{"cached_tokens":6}}}}`, usage) + require.Equal(t, 17, usage.InputTokens) + require.Equal(t, 19, usage.OutputTokens) + require.Equal(t, 6, usage.CacheReadInputTokens) + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage) require.Equal(t, 21, usage.InputTokens) require.Equal(t, 8, usage.OutputTokens) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 6eea0191..aa8e326b 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -369,7 +369,12 @@ func openAIWSEventMayContainToolCalls(eventType string) bool { } func openAIWSEventShouldParseUsage(eventType string) bool { - return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } } func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { @@ -2484,6 +2489,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( imageInputSize string payloadBytes int } + ingressSessionOriginalModel := "" applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { next, err := sjson.SetBytes(current, path, value) @@ -2547,12 +2553,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } originalModel := strings.TrimSpace(values[1].String()) + modelMissing := originalModel == "" if originalModel == "" { - return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( - coderws.StatusPolicyViolation, - "model is required in response.create payload", - nil, - ) + // 入站 WS 长会话里,部分客户端只在第一轮 response.create 上声明 + // model,后续 turn 复用同一 session-level model。为避免因省略 + // model 直接断开用户连接,这里回落到上一轮已通过校验的客户端模型, + // 并在下方写回上游 payload,保证账号模型映射/fast policy/图片权限 + // 仍按同一模型执行。 + originalModel = ingressSessionOriginalModel + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } } promptCacheKey := strings.TrimSpace(values[2].String()) previousResponseID := strings.TrimSpace(values[3].String()) @@ -2572,7 +2587,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( normalized = next } upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) - if upstreamModel != originalModel { + if modelMissing || upstreamModel != originalModel { next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) @@ -2602,11 +2617,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( // single integration point for all WS ingress turns (first + follow-up // frames flow through here). // - // Model fallback: parseClientPayload above rejects any frame whose - // "model" field is missing (line ~2493-2500), so by the time we - // reach this point upstreamModel is always derived from a non-empty - // per-frame model. The capturedSessionModel fallback used in the - // passthrough adapter is therefore not needed in this path. + // Model fallback: first turn still requires model at the handler layer; + // follow-up response.create frames may omit it and then reuse + // ingressSessionOriginalModel. We always write a concrete upstream model + // before evaluating policy, so whitelist / filter behavior remains stable. policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized) if policyErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr) @@ -2635,6 +2649,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } normalized = policyApplied + ingressSessionOriginalModel = originalModel return openAIWSClientPayload{ payloadRaw: normalized, diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go index 0350bde9..2622f7f2 100644 --- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -39,6 +39,24 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { require.Equal(t, 4, usage.CacheReadInputTokens) } +func TestOpenAIWSEventShouldParseUsageTerminalEvents(t *testing.T) { + t.Parallel() + + for _, eventType := range []string{ + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled", + } { + require.True(t, openAIWSEventShouldParseUsage(eventType), eventType) + require.True(t, openAIWSEventShouldParseUsage(" "+eventType+" "), eventType) + } + require.False(t, openAIWSEventShouldParseUsage("response.output_text.delta")) + require.False(t, openAIWSEventShouldParseUsage("")) +} + func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index edb6fbcd..b7f1bc4f 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -164,6 +164,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_omit_model_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_omit_model_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 115, + Name: "openai-ingress-omit-model", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "client-model": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"client-model","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, firstEvent, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "resp_omit_model_1", gjson.GetBytes(firstEvent, "response.id").String()) + + writeCtx, cancelWrite = context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","stream":false,"previous_response_id":"resp_omit_model_1"}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead = context.WithTimeout(context.Background(), 3*time.Second) + _, secondEvent, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "resp_omit_model_2", gjson.GetBytes(secondEvent, "response.id").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Len(t, captureConn.writes, 2) + require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[0]), "model").String()) + require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[1]), "model").String()) + require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String()) +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { gin.SetMode(gin.TestMode) @@ -441,6 +575,124 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_passthrough_headers","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: upstreamConn} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPassthroughDialer: captureDialer, + } + account := &Account{ + ID: 453, + Name: "openai-ingress-passthrough-headers", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + req.Header.Set(openAIWSTurnStateHeader, "turn-state-1") + req.Header.Set(openAIWSTurnMetadataHeader, "turn-meta-1") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "oauth-token", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"prompt_cache_key":"pcache_passthrough"}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "resp_passthrough_headers", gjson.GetBytes(event, "response.id").String()) + _ = clientConn.Close(coderws.StatusNormalClosure, "done") + + select { + case serverErr := <-serverErrCh: + if serverErr != nil { + require.Contains(t, serverErr.Error(), "StatusNormalClosure") + } + case <-time.After(5 * time.Second): + t.Fatal("等待 passthrough websocket 结束超时") + } + + require.Equal(t, isolateOpenAISessionID(0, "pcache_passthrough"), captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, "turn-state-1", captureDialer.lastHeaders.Get(openAIWSTurnStateHeader)) + require.Equal(t, "turn-meta-1", captureDialer.lastHeaders.Get(openAIWSTurnMetadataHeader)) +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index cd816533..e949560f 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -727,6 +727,70 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) } +func TestOpenAIGatewayService_Forward_WSv2_ResponseDoneUsageParsed(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.done","response":{"id":"resp_done_usage","model":"gpt-5.1","usage":{"input_tokens":13,"output_tokens":8,"input_tokens_details":{"cached_tokens":5},"cache_creation_input_tokens":2,"output_tokens_details":{"image_tokens":4}}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 32, + Name: "openai-ws-done", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hi"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_done_usage", result.RequestID) + require.Equal(t, 13, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.CacheReadInputTokens) + require.Equal(t, 2, result.Usage.CacheCreationInputTokens) + require.Equal(t, 4, result.Usage.ImageOutputTokens) +} + func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go index 35c7569d..6aba3b7d 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -25,6 +25,7 @@ type Usage struct { OutputTokens int CacheCreationInputTokens int CacheReadInputTokens int + ImageOutputTokens int } type RelayResult struct { @@ -756,8 +757,21 @@ func parseUsageAndAccumulate( } inputResult := gjson.GetBytes(message, "response.usage.input_tokens") + if !inputResult.Exists() { + inputResult = gjson.GetBytes(message, "response.usage.prompt_tokens") + } outputResult := gjson.GetBytes(message, "response.usage.output_tokens") + if !outputResult.Exists() { + outputResult = gjson.GetBytes(message, "response.usage.completion_tokens") + } cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens") + if !cachedResult.Exists() { + cachedResult = gjson.GetBytes(message, "response.usage.prompt_tokens_details.cached_tokens") + } + imageTokens := usageResult.Get("output_tokens_details.image_tokens").Int() + if imageTokens == 0 { + imageTokens = usageResult.Get("completion_tokens_details.image_tokens").Int() + } inputTokens, inputOK := parseUsageIntField(inputResult, true) outputTokens, outputOK := parseUsageIntField(outputResult, true) @@ -771,14 +785,18 @@ func parseUsageAndAccumulate( return Usage{} } parsedUsage := Usage{ - InputTokens: inputTokens, - OutputTokens: outputTokens, - CacheReadInputTokens: cachedTokens, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheCreationInputTokens: int(usageResult.Get("cache_creation_input_tokens").Int()), + CacheReadInputTokens: cachedTokens, + ImageOutputTokens: int(imageTokens), } state.usage.InputTokens += parsedUsage.InputTokens state.usage.OutputTokens += parsedUsage.OutputTokens + state.usage.CacheCreationInputTokens += parsedUsage.CacheCreationInputTokens state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens + state.usage.ImageOutputTokens += parsedUsage.ImageOutputTokens return parsedUsage } @@ -840,7 +858,7 @@ func isTerminalEvent(eventType string) bool { func shouldParseUsage(eventType string) bool { switch eventType { - case "response.completed", "response.done", "response.failed": + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": return true default: return false diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go index 52104482..13c51f66 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -300,20 +300,41 @@ func TestParseUsageAndEnrichCoverage(t *testing.T) { require.Equal(t, 0, state.usage.OutputTokens) require.Equal(t, 0, state.usage.CacheReadInputTokens) - parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil) + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1},"cache_creation_input_tokens":4,"output_tokens_details":{"image_tokens":3}}}}`), "response.completed", nil) require.Equal(t, 2, state.usage.InputTokens) require.Equal(t, 1, state.usage.OutputTokens) require.Equal(t, 1, state.usage.CacheReadInputTokens) + require.Equal(t, 4, state.usage.CacheCreationInputTokens) + require.Equal(t, 3, state.usage.ImageOutputTokens) result := &RelayResult{} enrichResult(result, state, 5*time.Millisecond) require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens) + require.Equal(t, state.usage.CacheCreationInputTokens, result.Usage.CacheCreationInputTokens) + require.Equal(t, state.usage.ImageOutputTokens, result.Usage.ImageOutputTokens) require.Equal(t, 5*time.Millisecond, result.Duration) parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil) require.Equal(t, 2, state.usage.InputTokens) enrichResult(nil, state, 0) } +func TestParseUsageAndAccumulateAcceptsChatUsageAliases(t *testing.T) { + t.Parallel() + + state := &relayState{} + got := parseUsageAndAccumulate( + state, + []byte(`{"type":"response.done","response":{"usage":{"prompt_tokens":12,"completion_tokens":6,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"image_tokens":2}}}}`), + "response.done", + nil, + ) + require.Equal(t, 12, got.InputTokens) + require.Equal(t, 6, got.OutputTokens) + require.Equal(t, 4, got.CacheReadInputTokens) + require.Equal(t, 2, got.ImageOutputTokens) + require.Equal(t, got, state.usage) +} + func TestEmitTurnCompleteCoverage(t *testing.T) { t.Parallel() @@ -377,6 +398,23 @@ func TestIsTokenEventCoverageBranches(t *testing.T) { require.True(t, isTokenEvent("response.done")) } +func TestShouldParseUsageTerminalEvents(t *testing.T) { + t.Parallel() + + for _, eventType := range []string{ + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled", + } { + require.True(t, shouldParseUsage(eventType), eventType) + } + require.False(t, shouldParseUsage("response.output_text.delta")) + require.False(t, shouldParseUsage("")) +} + func TestRelayTurnTimingHelpersCoverage(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 17543dc0..c93d0981 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -312,6 +312,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // goroutine)和 OnTurnComplete / final result(runUpstreamToClient // goroutine)之间同步当前 turn 的 usage metadata。 usageMeta.initFromFirstFrame(firstClientMessage) + promptCacheKey := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "prompt_cache_key").String()) wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { @@ -338,7 +339,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { isCodexCLI = true } - headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "") + turnState := "" + turnMetadata := "" + if c != nil { + turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) + } + headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, turnMetadata, promptCacheKey) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() @@ -519,6 +526,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( OutputTokens: turn.Usage.OutputTokens, CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, CacheReadInputTokens: turn.Usage.CacheReadInputTokens, + ImageOutputTokens: turn.Usage.ImageOutputTokens, }, Model: turn.RequestModel, ServiceTier: usageMeta.serviceTier.Load(), @@ -593,6 +601,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( OutputTokens: relayResult.Usage.OutputTokens, CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, + ImageOutputTokens: relayResult.Usage.ImageOutputTokens, }, Model: relayResult.RequestModel, ServiceTier: usageMeta.serviceTier.Load(),