diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index c0f98de4..7d503f5a 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -20,6 +20,32 @@ type FunctionCallOutputValidation struct { HasItemReferenceForAllCallIDs bool } +func isCodexToolCallContextItemType(typ string) bool { + switch strings.TrimSpace(typ) { + case "tool_call", + "function_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "mcp_tool_call": + return true + default: + return false + } +} + +func isCodexToolCallOutputItemType(typ string) bool { + switch strings.TrimSpace(typ) { + case "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output": + return true + default: + return false + } +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、 // 或显式声明 tools/tool_choice。 @@ -53,7 +79,9 @@ func NeedsToolContinuation(reqBody map[string]any) bool { return false } -// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +// AnalyzeToolContinuationSignals 单次遍历 input,提取工具输出/工具调用上下文/item_reference 相关信号。 +// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖 Codex 的所有工具输出 +// (function_call_output/tool_search_output/custom_tool_call_output/mcp_tool_call_output)。 func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { signals := ToolContinuationSignals{} if reqBody == nil { @@ -73,13 +101,13 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "tool_call", "function_call": + switch { + case isCodexToolCallContextItemType(itemType): callID, _ := itemMap["call_id"].(string) if strings.TrimSpace(callID) != "" { signals.HasToolCallContext = true } - case "function_call_output": + case isCodexToolCallOutputItemType(itemType): signals.HasFunctionCallOutput = true callID, _ := itemMap["call_id"].(string) callID = strings.TrimSpace(callID) @@ -91,7 +119,7 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign callIDs = make(map[string]struct{}) } callIDs[callID] = struct{}{} - case "item_reference": + case itemType == "item_reference": signals.HasItemReference = true idValue, _ := itemMap["id"].(string) idValue = strings.TrimSpace(idValue) @@ -123,9 +151,10 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign } // ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: -// 1) 无 function_call_output 直接返回 -// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 1) 无工具输出直接返回 +// 2) 若已存在工具调用上下文则提前返回 // 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖所有 Codex 工具输出。 func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { result := FunctionCallOutputValidation{} if reqBody == nil { @@ -142,10 +171,10 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "function_call_output": + switch { + case isCodexToolCallOutputItemType(itemType): result.HasFunctionCallOutput = true - case "tool_call", "function_call": + case isCodexToolCallContextItemType(itemType): callID, _ := itemMap["call_id"].(string) if strings.TrimSpace(callID) != "" { result.HasToolCallContext = true @@ -168,8 +197,8 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "function_call_output": + switch { + case isCodexToolCallOutputItemType(itemType): callID, _ := itemMap["call_id"].(string) callID = strings.TrimSpace(callID) if callID == "" { @@ -177,7 +206,7 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } callIDs[callID] = struct{}{} - case "item_reference": + case itemType == "item_reference": idValue, _ := itemMap["id"].(string) idValue = strings.TrimSpace(idValue) if idValue == "" { @@ -201,24 +230,25 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu return result } -// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 +// HasFunctionCallOutput 判断 input 是否包含任意 Codex 工具输出,用于触发续链校验。 +// 名称保留 function_call_output 是为了兼容既有调用点。 func HasFunctionCallOutput(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput } -// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, -// 用于判断 function_call_output 是否具备可关联的上下文。 +// HasToolCallContext 判断 input 是否包含带 call_id 的工具调用上下文, +// 用于判断工具输出是否具备可关联的上下文。 func HasToolCallContext(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext } -// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 +// FunctionCallOutputCallIDs 提取 input 中工具输出的 call_id 集合。 // 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 func FunctionCallOutputCallIDs(reqBody map[string]any) []string { return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs } -// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 +// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的工具输出。 func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go index 3f415d9d..0e0552f6 100644 --- a/backend/internal/service/openai_tool_continuation_test.go +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -38,41 +38,57 @@ func TestNeedsToolContinuationSignals(t *testing.T) { } func TestHasFunctionCallOutput(t *testing.T) { - // 仅当 input 中存在 function_call_output 才视为续链输出。 + // 所有 Codex 工具输出都应视为续链输出,避免 WS 续链时丢失 previous_response_id。 require.False(t, HasFunctionCallOutput(nil)) - require.True(t, HasFunctionCallOutput(map[string]any{ - "input": []any{map[string]any{"type": "function_call_output"}}, - })) + for _, typ := range []string{ + "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output", + } { + require.True(t, HasFunctionCallOutput(map[string]any{ + "input": []any{map[string]any{"type": typ}}, + }), typ) + } require.False(t, HasFunctionCallOutput(map[string]any{ "input": "text", })) } func TestHasToolCallContext(t *testing.T) { - // tool_call/function_call 必须包含 call_id,才能作为可关联上下文。 + // 工具调用上下文必须包含 call_id,才能作为可关联上下文。 require.False(t, HasToolCallContext(nil)) - require.True(t, HasToolCallContext(map[string]any{ - "input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}}, - })) - require.True(t, HasToolCallContext(map[string]any{ - "input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}}, - })) + for _, typ := range []string{ + "tool_call", + "function_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "mcp_tool_call", + } { + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": typ, "call_id": "call_1"}}, + }), typ) + } require.False(t, HasToolCallContext(map[string]any{ "input": []any{map[string]any{"type": "tool_call"}}, })) } func TestFunctionCallOutputCallIDs(t *testing.T) { - // 仅提取非空 call_id,去重后返回。 + // 仅提取工具输出的非空 call_id,去重后返回。 require.Empty(t, FunctionCallOutputCallIDs(nil)) callIDs := FunctionCallOutputCallIDs(map[string]any{ "input": []any{ map[string]any{"type": "function_call_output", "call_id": "call_1"}, + map[string]any{"type": "tool_search_output", "call_id": "call_search"}, + map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom"}, + map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp"}, map[string]any{"type": "function_call_output", "call_id": ""}, map[string]any{"type": "function_call_output", "call_id": "call_1"}, }, }) - require.ElementsMatch(t, []string{"call_1"}, callIDs) + require.ElementsMatch(t, []string{"call_1", "call_search", "call_custom", "call_mcp"}, callIDs) } func TestHasFunctionCallOutputMissingCallID(t *testing.T) { @@ -80,8 +96,11 @@ func TestHasFunctionCallOutputMissingCallID(t *testing.T) { require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ "input": []any{map[string]any{"type": "function_call_output"}}, })) + require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "tool_search_output"}}, + })) require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{ - "input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}}, + "input": []any{map[string]any{"type": "tool_search_output", "call_id": "call_1"}}, })) } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 700dbedf..5edf4db9 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1548,13 +1548,35 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool { for _, item := range items { - if gjson.GetBytes(item, "type").String() == "function_call_output" { + if isCodexToolCallOutputItemType(gjson.GetBytes(item, "type").String()) { return true } } return false } +func openAIWSRawPayloadHasToolCallOutput(payload []byte) bool { + if len(payload) == 0 { + return false + } + input := gjson.GetBytes(payload, "input") + if !input.Exists() { + return false + } + if input.IsArray() { + for _, item := range input.Array() { + if isCodexToolCallOutputItemType(item.Get("type").String()) { + return true + } + } + return false + } + if input.Type == gjson.JSON { + return isCodexToolCallOutputItemType(input.Get("type").String()) + } + return false +} + func buildOpenAIWSReplayInputSequence( previousFullInput []json.RawMessage, previousFullInputExists bool, @@ -2855,7 +2877,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) - turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + turnHasFunctionCallOutput := openAIWSRawPayloadHasToolCallOutput(payload) eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 @@ -3131,7 +3153,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentTurnReplayInputExists := false skipBeforeTurn := false hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool { - if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() { + if openAIWSRawPayloadHasToolCallOutput(payload) { return true } return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput) @@ -3256,7 +3278,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) toolSignals := ToolContinuationSignals{ - HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(), + HasFunctionCallOutput: openAIWSRawPayloadHasToolCallOutput(currentPayload), } if toolSignals.HasFunctionCallOutput { var currentReqBody map[string]any diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index a4b39ddf..edb6fbcd 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1223,6 +1223,141 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledToolSearchOutputAutoAttachesPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 145, + Name: "openai-ingress-tool-search-output-auto-prev", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_tool_search_prev_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"tool_search_output","call_id":"call_search_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_tool_search_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.Equal(t, "resp_tool_search_prev_1", gjson.Get(secondWrite, "previous_response_id").String(), "tool_search_output 缺失 previous_response_id 时应回填上一轮响应 ID") + require.Equal(t, "tool_search_output", gjson.Get(secondWrite, "input.0.type").String()) + require.Equal(t, "call_search_1", gjson.Get(secondWrite, "input.0.call_id").String()) +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index c735f50a..31c9a142 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -696,6 +696,36 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { }) } +func TestOpenAIWSRawPayloadHasToolCallOutput(t *testing.T) { + t.Parallel() + + for _, typ := range []string{ + "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output", + } { + typ := typ + t.Run(typ, func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":[{"type":"` + typ + `","call_id":"call_1","output":"ok"}]}`) + require.True(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) + } + + t.Run("object_input", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":{"type":"tool_search_output","call_id":"call_1","output":"ok"}}`) + require.True(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) + + t.Run("non_tool_output", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":[{"type":"input_text","text":"hello"}]}`) + require.False(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) +} + func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { t.Parallel()