修复 OpenAI WS 兼容性与 usage 统计
- 对齐 WS 与流式终态 usage 解析,补齐 failed/done/incomplete/cancelled 等事件 - 兼容后续 WS response.create 省略 model,保持模型映射与权限判断一致 - 补齐 passthrough header 透传和图片 usage 字段映射
This commit is contained in:
parent
89d96f4b25
commit
d7bed40dda
@ -4861,7 +4861,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
eventType := gjson.GetBytes(data, "type").String()
|
eventType := gjson.GetBytes(data, "type").String()
|
||||||
if eventType != "response.completed" && eventType != "response.done" &&
|
if eventType != "response.completed" && eventType != "response.done" && eventType != "response.failed" &&
|
||||||
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2218,6 +2218,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
|||||||
require.Equal(t, 15, usage.OutputTokens)
|
require.Equal(t, 15, usage.OutputTokens)
|
||||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
// failed 事件在部分上游路径也会携带已消耗 usage,应与 WS/passthrough 保持一致
|
||||||
|
svc.parseSSEUsage(`{"type":"response.failed","response":{"usage":{"input_tokens":17,"output_tokens":19,"input_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||||
|
require.Equal(t, 17, usage.InputTokens)
|
||||||
|
require.Equal(t, 19, usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
|
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||||
require.Equal(t, 21, usage.InputTokens)
|
require.Equal(t, 21, usage.InputTokens)
|
||||||
require.Equal(t, 8, usage.OutputTokens)
|
require.Equal(t, 8, usage.OutputTokens)
|
||||||
|
|||||||
@ -369,7 +369,12 @@ func openAIWSEventMayContainToolCalls(eventType string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func openAIWSEventShouldParseUsage(eventType string) bool {
|
func openAIWSEventShouldParseUsage(eventType string) bool {
|
||||||
return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
|
switch strings.TrimSpace(eventType) {
|
||||||
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
||||||
@ -2484,6 +2489,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
imageInputSize string
|
imageInputSize string
|
||||||
payloadBytes int
|
payloadBytes int
|
||||||
}
|
}
|
||||||
|
ingressSessionOriginalModel := ""
|
||||||
|
|
||||||
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
||||||
next, err := sjson.SetBytes(current, path, value)
|
next, err := sjson.SetBytes(current, path, value)
|
||||||
@ -2547,12 +2553,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
originalModel := strings.TrimSpace(values[1].String())
|
originalModel := strings.TrimSpace(values[1].String())
|
||||||
|
modelMissing := originalModel == ""
|
||||||
if originalModel == "" {
|
if originalModel == "" {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
// 入站 WS 长会话里,部分客户端只在第一轮 response.create 上声明
|
||||||
coderws.StatusPolicyViolation,
|
// model,后续 turn 复用同一 session-level model。为避免因省略
|
||||||
"model is required in response.create payload",
|
// model 直接断开用户连接,这里回落到上一轮已通过校验的客户端模型,
|
||||||
nil,
|
// 并在下方写回上游 payload,保证账号模型映射/fast policy/图片权限
|
||||||
)
|
// 仍按同一模型执行。
|
||||||
|
originalModel = ingressSessionOriginalModel
|
||||||
|
if originalModel == "" {
|
||||||
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"model is required in response.create payload",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
promptCacheKey := strings.TrimSpace(values[2].String())
|
promptCacheKey := strings.TrimSpace(values[2].String())
|
||||||
previousResponseID := strings.TrimSpace(values[3].String())
|
previousResponseID := strings.TrimSpace(values[3].String())
|
||||||
@ -2572,7 +2587,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
normalized = next
|
normalized = next
|
||||||
}
|
}
|
||||||
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
|
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
|
||||||
if upstreamModel != originalModel {
|
if modelMissing || upstreamModel != originalModel {
|
||||||
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
||||||
if setErr != nil {
|
if setErr != nil {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||||||
@ -2602,11 +2617,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
// single integration point for all WS ingress turns (first + follow-up
|
// single integration point for all WS ingress turns (first + follow-up
|
||||||
// frames flow through here).
|
// frames flow through here).
|
||||||
//
|
//
|
||||||
// Model fallback: parseClientPayload above rejects any frame whose
|
// Model fallback: first turn still requires model at the handler layer;
|
||||||
// "model" field is missing (line ~2493-2500), so by the time we
|
// follow-up response.create frames may omit it and then reuse
|
||||||
// reach this point upstreamModel is always derived from a non-empty
|
// ingressSessionOriginalModel. We always write a concrete upstream model
|
||||||
// per-frame model. The capturedSessionModel fallback used in the
|
// before evaluating policy, so whitelist / filter behavior remains stable.
|
||||||
// passthrough adapter is therefore not needed in this path.
|
|
||||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||||
@ -2635,6 +2649,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
normalized = policyApplied
|
normalized = policyApplied
|
||||||
|
ingressSessionOriginalModel = originalModel
|
||||||
|
|
||||||
return openAIWSClientPayload{
|
return openAIWSClientPayload{
|
||||||
payloadRaw: normalized,
|
payloadRaw: normalized,
|
||||||
|
|||||||
@ -39,6 +39,24 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
|
|||||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIWSEventShouldParseUsageTerminalEvents(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, eventType := range []string{
|
||||||
|
"response.completed",
|
||||||
|
"response.done",
|
||||||
|
"response.failed",
|
||||||
|
"response.incomplete",
|
||||||
|
"response.cancelled",
|
||||||
|
"response.canceled",
|
||||||
|
} {
|
||||||
|
require.True(t, openAIWSEventShouldParseUsage(eventType), eventType)
|
||||||
|
require.True(t, openAIWSEventShouldParseUsage(" "+eventType+" "), eventType)
|
||||||
|
}
|
||||||
|
require.False(t, openAIWSEventShouldParseUsage("response.output_text.delta"))
|
||||||
|
require.False(t, openAIWSEventShouldParseUsage(""))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
||||||
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
||||||
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||||
|
|||||||
@ -164,6 +164,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
|||||||
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
|
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
captureConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(captureDialer)
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 115,
|
||||||
|
Name: "openai-ingress-omit-model",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"client-model": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"client-model","stream":false}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, firstEvent, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_omit_model_1", gjson.GetBytes(firstEvent, "response.id").String())
|
||||||
|
|
||||||
|
writeCtx, cancelWrite = context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","stream":false,"previous_response_id":"resp_omit_model_1"}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead = context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, secondEvent, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_omit_model_2", gjson.GetBytes(secondEvent, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 ingress websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Len(t, captureConn.writes, 2)
|
||||||
|
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[0]), "model").String())
|
||||||
|
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[1]), "model").String())
|
||||||
|
require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@ -441,6 +575,124 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
|||||||
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
upstreamConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_headers","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPassthroughDialer: captureDialer,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 453,
|
||||||
|
Name: "openai-ingress-passthrough-headers",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
||||||
|
req.Header.Set(openAIWSTurnStateHeader, "turn-state-1")
|
||||||
|
req.Header.Set(openAIWSTurnMetadataHeader, "turn-meta-1")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "oauth-token", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"prompt_cache_key":"pcache_passthrough"}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, event, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_passthrough_headers", gjson.GetBytes(event, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
if serverErr != nil {
|
||||||
|
require.Contains(t, serverErr.Error(), "StatusNormalClosure")
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 passthrough websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, isolateOpenAISessionID(0, "pcache_passthrough"), captureDialer.lastHeaders.Get("session_id"))
|
||||||
|
require.Equal(t, "turn-state-1", captureDialer.lastHeaders.Get(openAIWSTurnStateHeader))
|
||||||
|
require.Equal(t, "turn-meta-1", captureDialer.lastHeaders.Get(openAIWSTurnMetadataHeader))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -727,6 +727,70 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
|
|||||||
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_Forward_WSv2_ResponseDoneUsageParsed(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||||
|
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||||
|
|
||||||
|
captureConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.done","response":{"id":"resp_done_usage","model":"gpt-5.1","usage":{"input_tokens":13,"output_tokens":8,"input_tokens_details":{"cached_tokens":5},"cache_creation_input_tokens":2,"output_tokens_details":{"image_tokens":4}}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(captureDialer)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 32,
|
||||||
|
Name: "openai-ws-done",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hi"}]}`)
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "resp_done_usage", result.RequestID)
|
||||||
|
require.Equal(t, 13, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 2, result.Usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.ImageOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ type Usage struct {
|
|||||||
OutputTokens int
|
OutputTokens int
|
||||||
CacheCreationInputTokens int
|
CacheCreationInputTokens int
|
||||||
CacheReadInputTokens int
|
CacheReadInputTokens int
|
||||||
|
ImageOutputTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelayResult struct {
|
type RelayResult struct {
|
||||||
@ -756,8 +757,21 @@ func parseUsageAndAccumulate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||||
|
if !inputResult.Exists() {
|
||||||
|
inputResult = gjson.GetBytes(message, "response.usage.prompt_tokens")
|
||||||
|
}
|
||||||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||||
|
if !outputResult.Exists() {
|
||||||
|
outputResult = gjson.GetBytes(message, "response.usage.completion_tokens")
|
||||||
|
}
|
||||||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||||
|
if !cachedResult.Exists() {
|
||||||
|
cachedResult = gjson.GetBytes(message, "response.usage.prompt_tokens_details.cached_tokens")
|
||||||
|
}
|
||||||
|
imageTokens := usageResult.Get("output_tokens_details.image_tokens").Int()
|
||||||
|
if imageTokens == 0 {
|
||||||
|
imageTokens = usageResult.Get("completion_tokens_details.image_tokens").Int()
|
||||||
|
}
|
||||||
|
|
||||||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||||
@ -771,14 +785,18 @@ func parseUsageAndAccumulate(
|
|||||||
return Usage{}
|
return Usage{}
|
||||||
}
|
}
|
||||||
parsedUsage := Usage{
|
parsedUsage := Usage{
|
||||||
InputTokens: inputTokens,
|
InputTokens: inputTokens,
|
||||||
OutputTokens: outputTokens,
|
OutputTokens: outputTokens,
|
||||||
CacheReadInputTokens: cachedTokens,
|
CacheCreationInputTokens: int(usageResult.Get("cache_creation_input_tokens").Int()),
|
||||||
|
CacheReadInputTokens: cachedTokens,
|
||||||
|
ImageOutputTokens: int(imageTokens),
|
||||||
}
|
}
|
||||||
|
|
||||||
state.usage.InputTokens += parsedUsage.InputTokens
|
state.usage.InputTokens += parsedUsage.InputTokens
|
||||||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||||
|
state.usage.CacheCreationInputTokens += parsedUsage.CacheCreationInputTokens
|
||||||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||||
|
state.usage.ImageOutputTokens += parsedUsage.ImageOutputTokens
|
||||||
return parsedUsage
|
return parsedUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -840,7 +858,7 @@ func isTerminalEvent(eventType string) bool {
|
|||||||
|
|
||||||
func shouldParseUsage(eventType string) bool {
|
func shouldParseUsage(eventType string) bool {
|
||||||
switch eventType {
|
switch eventType {
|
||||||
case "response.completed", "response.done", "response.failed":
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -300,20 +300,41 @@ func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
|||||||
require.Equal(t, 0, state.usage.OutputTokens)
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1},"cache_creation_input_tokens":4,"output_tokens_details":{"image_tokens":3}}}}`), "response.completed", nil)
|
||||||
require.Equal(t, 2, state.usage.InputTokens)
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
require.Equal(t, 1, state.usage.OutputTokens)
|
require.Equal(t, 1, state.usage.OutputTokens)
|
||||||
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 4, state.usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, 3, state.usage.ImageOutputTokens)
|
||||||
|
|
||||||
result := &RelayResult{}
|
result := &RelayResult{}
|
||||||
enrichResult(result, state, 5*time.Millisecond)
|
enrichResult(result, state, 5*time.Millisecond)
|
||||||
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, state.usage.CacheCreationInputTokens, result.Usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, state.usage.ImageOutputTokens, result.Usage.ImageOutputTokens)
|
||||||
require.Equal(t, 5*time.Millisecond, result.Duration)
|
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||||
require.Equal(t, 2, state.usage.InputTokens)
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
enrichResult(nil, state, 0)
|
enrichResult(nil, state, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseUsageAndAccumulateAcceptsChatUsageAliases(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
got := parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.done","response":{"usage":{"prompt_tokens":12,"completion_tokens":6,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"image_tokens":2}}}}`),
|
||||||
|
"response.done",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 12, got.InputTokens)
|
||||||
|
require.Equal(t, 6, got.OutputTokens)
|
||||||
|
require.Equal(t, 4, got.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 2, got.ImageOutputTokens)
|
||||||
|
require.Equal(t, got, state.usage)
|
||||||
|
}
|
||||||
|
|
||||||
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -377,6 +398,23 @@ func TestIsTokenEventCoverageBranches(t *testing.T) {
|
|||||||
require.True(t, isTokenEvent("response.done"))
|
require.True(t, isTokenEvent("response.done"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldParseUsageTerminalEvents(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, eventType := range []string{
|
||||||
|
"response.completed",
|
||||||
|
"response.done",
|
||||||
|
"response.failed",
|
||||||
|
"response.incomplete",
|
||||||
|
"response.cancelled",
|
||||||
|
"response.canceled",
|
||||||
|
} {
|
||||||
|
require.True(t, shouldParseUsage(eventType), eventType)
|
||||||
|
}
|
||||||
|
require.False(t, shouldParseUsage("response.output_text.delta"))
|
||||||
|
require.False(t, shouldParseUsage(""))
|
||||||
|
}
|
||||||
|
|
||||||
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@ -312,6 +312,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||||
|
promptCacheKey := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "prompt_cache_key").String())
|
||||||
|
|
||||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -338,7 +339,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||||
isCodexCLI = true
|
isCodexCLI = true
|
||||||
}
|
}
|
||||||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
turnState := ""
|
||||||
|
turnMetadata := ""
|
||||||
|
if c != nil {
|
||||||
|
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
|
||||||
|
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
|
||||||
|
}
|
||||||
|
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
@ -519,6 +526,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
OutputTokens: turn.Usage.OutputTokens,
|
OutputTokens: turn.Usage.OutputTokens,
|
||||||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||||
|
ImageOutputTokens: turn.Usage.ImageOutputTokens,
|
||||||
},
|
},
|
||||||
Model: turn.RequestModel,
|
Model: turn.RequestModel,
|
||||||
ServiceTier: usageMeta.serviceTier.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
@ -593,6 +601,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
OutputTokens: relayResult.Usage.OutputTokens,
|
OutputTokens: relayResult.Usage.OutputTokens,
|
||||||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||||
|
ImageOutputTokens: relayResult.Usage.ImageOutputTokens,
|
||||||
},
|
},
|
||||||
Model: relayResult.RequestModel,
|
Model: relayResult.RequestModel,
|
||||||
ServiceTier: usageMeta.serviceTier.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user