fix: recognize codex tool outputs in ws continuation

This commit is contained in:
siyuan 2026-05-25 10:46:58 +08:00
parent 63b0631a58
commit fc66cd704a
5 changed files with 272 additions and 36 deletions

View File

@ -20,6 +20,32 @@ type FunctionCallOutputValidation struct {
HasItemReferenceForAllCallIDs bool
}
// isCodexToolCallContextItemType reports whether typ, once surrounding
// whitespace is trimmed, is one of the input item types treated as tool-call
// context. Matching is exact and case-sensitive.
func isCodexToolCallContextItemType(typ string) bool {
	trimmed := strings.TrimSpace(typ)
	for _, known := range [...]string{
		"tool_call",
		"function_call",
		"local_shell_call",
		"tool_search_call",
		"custom_tool_call",
		"mcp_tool_call",
	} {
		if trimmed == known {
			return true
		}
	}
	return false
}
// isCodexToolCallOutputItemType reports whether typ, once surrounding
// whitespace is trimmed, is one of the input item types treated as a tool
// output. Matching is exact and case-sensitive.
func isCodexToolCallOutputItemType(typ string) bool {
	trimmed := strings.TrimSpace(typ)
	for _, known := range [...]string{
		"function_call_output",
		"tool_search_output",
		"custom_tool_call_output",
		"mcp_tool_call_output",
	} {
		if trimmed == known {
			return true
		}
	}
	return false
}
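// NeedsToolContinuation reports whether the request needs tool-call continuation handling.
// Any one of the following signals marks it as a continuation: previous_response_id, an input
// that contains a tool output or item_reference, or explicitly declared tools/tool_choice.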
@ -53,7 +79,9 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
return false
}
// AnalyzeToolContinuationSignals 单次遍历 input提取 function_call_output/tool_call/item_reference 相关信号。
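// AnalyzeToolContinuationSignals walks input once and extracts the signals related to tool outputs, tool-call context, and item_reference.
// The FunctionCallOutput field names are kept for compatibility with existing call sites; semantically they cover every Codex tool output type:
// function_call_output/tool_search_output/custom_tool_call_output/mcp_tool_call_output.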
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
signals := ToolContinuationSignals{}
if reqBody == nil {
@ -73,13 +101,13 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
continue
}
itemType, _ := itemMap["type"].(string)
switch itemType {
case "tool_call", "function_call":
switch {
case isCodexToolCallContextItemType(itemType):
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) != "" {
signals.HasToolCallContext = true
}
case "function_call_output":
case isCodexToolCallOutputItemType(itemType):
signals.HasFunctionCallOutput = true
callID, _ := itemMap["call_id"].(string)
callID = strings.TrimSpace(callID)
@ -91,7 +119,7 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
callIDs = make(map[string]struct{})
}
callIDs[callID] = struct{}{}
case "item_reference":
case itemType == "item_reference":
signals.HasItemReference = true
idValue, _ := itemMap["id"].(string)
idValue = strings.TrimSpace(idValue)
@ -123,9 +151,10 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
}
// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果:
// 1) 无 function_call_output 直接返回
// 2) 若已存在 tool_call/function_call 上下文则提前返回
// 1) 无工具输出直接返回
// 2) 若已存在工具调用上下文则提前返回
// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合
// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖所有 Codex 工具输出。
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
result := FunctionCallOutputValidation{}
if reqBody == nil {
@ -142,10 +171,10 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
continue
}
itemType, _ := itemMap["type"].(string)
switch itemType {
case "function_call_output":
switch {
case isCodexToolCallOutputItemType(itemType):
result.HasFunctionCallOutput = true
case "tool_call", "function_call":
case isCodexToolCallContextItemType(itemType):
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) != "" {
result.HasToolCallContext = true
@ -168,8 +197,8 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
continue
}
itemType, _ := itemMap["type"].(string)
switch itemType {
case "function_call_output":
switch {
case isCodexToolCallOutputItemType(itemType):
callID, _ := itemMap["call_id"].(string)
callID = strings.TrimSpace(callID)
if callID == "" {
@ -177,7 +206,7 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
continue
}
callIDs[callID] = struct{}{}
case "item_reference":
case itemType == "item_reference":
idValue, _ := itemMap["id"].(string)
idValue = strings.TrimSpace(idValue)
if idValue == "" {
@ -201,24 +230,25 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
return result
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output用于触发续链校验。
// HasFunctionCallOutput reports whether input contains any Codex tool output
// item; it is used to trigger continuation validation.
// The name still says function_call_output for compatibility with existing call sites.
func HasFunctionCallOutput(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutput
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call
// 用于判断 function_call_output 是否具备可关联的上下文。
// HasToolCallContext reports whether input contains a tool-call context item
// that carries a call_id; it is used to decide whether a tool output has
// context it can be associated with.
func HasToolCallContext(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasToolCallContext
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// FunctionCallOutputCallIDs returns the call_id values collected from the tool
// output items in input, as reported by AnalyzeToolContinuationSignals.
// Only non-empty call_ids are returned; they are meant to be matched against
// item_reference.id during validation.
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output
// HasFunctionCallOutputMissingCallID reports whether input contains a tool
// output item that lacks a call_id.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	signals := AnalyzeToolContinuationSignals(reqBody)
	return signals.HasFunctionCallOutputMissingCallID
}

View File

@ -38,41 +38,57 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
}
// TestHasFunctionCallOutput checks that HasFunctionCallOutput recognizes every
// Codex tool output item type and ignores nil bodies and non-array input.
func TestHasFunctionCallOutput(t *testing.T) {
	// Every Codex tool output must count as a continuation output; otherwise a
	// WS continuation would lose its previous_response_id.
	require.False(t, HasFunctionCallOutput(nil))
	// function_call_output is part of the list, so no separate assertion for
	// it is needed.
	for _, typ := range []string{
		"function_call_output",
		"tool_search_output",
		"custom_tool_call_output",
		"mcp_tool_call_output",
	} {
		require.True(t, HasFunctionCallOutput(map[string]any{
			"input": []any{map[string]any{"type": typ}},
		}), typ)
	}
	// A plain string input carries no items and therefore no tool output.
	require.False(t, HasFunctionCallOutput(map[string]any{
		"input": "text",
	}))
}
// TestHasToolCallContext checks that HasToolCallContext accepts every tool-call
// context item type when a call_id is present and rejects items without one.
func TestHasToolCallContext(t *testing.T) {
	// A tool-call context item must carry a call_id to be usable as
	// associable context.
	require.False(t, HasToolCallContext(nil))
	// tool_call and function_call are part of the list, so no separate
	// assertions for them are needed.
	for _, typ := range []string{
		"tool_call",
		"function_call",
		"local_shell_call",
		"tool_search_call",
		"custom_tool_call",
		"mcp_tool_call",
	} {
		require.True(t, HasToolCallContext(map[string]any{
			"input": []any{map[string]any{"type": typ, "call_id": "call_1"}},
		}), typ)
	}
	// Without a call_id the item does not count as context.
	require.False(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "tool_call"}},
	}))
}
// TestFunctionCallOutputCallIDs checks that FunctionCallOutputCallIDs collects
// the call_id of every tool output type, skipping empty values and duplicates.
func TestFunctionCallOutputCallIDs(t *testing.T) {
	// Only non-empty call_ids of tool outputs are extracted, de-duplicated.
	require.Empty(t, FunctionCallOutputCallIDs(nil))
	callIDs := FunctionCallOutputCallIDs(map[string]any{
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
			map[string]any{"type": "tool_search_output", "call_id": "call_search"},
			map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom"},
			map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp"},
			map[string]any{"type": "function_call_output", "call_id": ""},
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
	})
	// The input holds four distinct non-empty call_ids, so the single-element
	// expectation {"call_1"} no longer applies and would fail.
	require.ElementsMatch(t, []string{"call_1", "call_search", "call_custom", "call_mcp"}, callIDs)
}
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
@ -80,8 +96,11 @@ func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output"}},
}))
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "tool_search_output"}},
}))
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
"input": []any{map[string]any{"type": "tool_search_output", "call_id": "call_1"}},
}))
}

View File

@ -1548,13 +1548,35 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
for _, item := range items {
if gjson.GetBytes(item, "type").String() == "function_call_output" {
if isCodexToolCallOutputItemType(gjson.GetBytes(item, "type").String()) {
return true
}
}
return false
}
// openAIWSRawPayloadHasToolCallOutput reports whether the payload's "input"
// field carries a Codex tool output item.
//
// An array input matches when any element's "type" is a tool output type; a
// non-array JSON input is checked through its own "type" field. An empty
// payload, a missing input, or any other input kind yields false.
func openAIWSRawPayloadHasToolCallOutput(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}
	input := gjson.GetBytes(payload, "input")
	switch {
	case !input.Exists():
		return false
	case input.IsArray():
		for _, entry := range input.Array() {
			if isCodexToolCallOutputItemType(entry.Get("type").String()) {
				return true
			}
		}
		return false
	case input.Type == gjson.JSON:
		// Non-array JSON input: inspect its own "type" field.
		return isCodexToolCallOutputItemType(input.Get("type").String())
	default:
		return false
	}
}
func buildOpenAIWSReplayInputSequence(
previousFullInput []json.RawMessage,
previousFullInputExists bool,
@ -2855,7 +2877,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
turnHasFunctionCallOutput := openAIWSRawPayloadHasToolCallOutput(payload)
eventCount := 0
tokenEventCount := 0
terminalEventCount := 0
@ -3131,7 +3153,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentTurnReplayInputExists := false
skipBeforeTurn := false
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
if openAIWSRawPayloadHasToolCallOutput(payload) {
return true
}
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
@ -3256,7 +3278,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
toolSignals := ToolContinuationSignals{
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
HasFunctionCallOutput: openAIWSRawPayloadHasToolCallOutput(currentPayload),
}
if toolSignals.HasFunctionCallOutput {
var currentReqBody map[string]any

View File

@ -1223,6 +1223,141 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID")
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledToolSearchOutputAutoAttachesPreviousResponseID
// drives two turns through ProxyResponsesWebSocketFromClient with store=false.
// The second turn sends a tool_search_output item without previous_response_id,
// and the test asserts that the request written upstream carries the response
// ID of the first turn while the input item is forwarded unchanged.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledToolSearchOutputAutoAttachesPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
// Gateway configuration: WS v2 enabled, at most one connection per account.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
// Fake upstream connection preloaded with two response.completed events,
// presumably delivered one per turn (verify against openAIWSCaptureConn).
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 145,
Name: "openai-ingress-tool-search-output-auto-prev",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// Buffered so the handler can report its single result without blocking.
serverErrCh := make(chan error, 1)
// Local ingress server: accepts the client websocket, reads the first
// message, and hands the connection to the service under test.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
// Dial the ingress server as a websocket client (http:// -> ws://).
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
// Helpers bound each client write/read with a 3 second timeout.
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
// Turn 1: plain text input; establishes the response ID of the previous turn.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_tool_search_prev_1", gjson.GetBytes(firstTurn, "response.id").String())
// Turn 2: tool_search_output sent without previous_response_id.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"tool_search_output","call_id":"call_search_1","output":"ok"}]}`)
secondTurn := readMessage()
require.Equal(t, "resp_tool_search_prev_2", gjson.GetBytes(secondTurn, "response.id").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
// Wait for the proxy handler to finish and require a clean exit.
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
// A single upstream dial served both turns, and exactly two requests were written.
require.Equal(t, 1, captureDialer.DialCount())
require.Len(t, captureConn.writes, 2)
// The second upstream request must have previous_response_id filled in with
// the first turn's response ID, with the tool output item left intact.
secondWrite := requestToJSONString(captureConn.writes[1])
require.Equal(t, "resp_tool_search_prev_1", gjson.Get(secondWrite, "previous_response_id").String(), "tool_search_output 缺失 previous_response_id 时应回填上一轮响应 ID")
require.Equal(t, "tool_search_output", gjson.Get(secondWrite, "input.0.type").String())
require.Equal(t, "call_search_1", gjson.Get(secondWrite, "input.0.call_id").String())
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -696,6 +696,36 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
})
}
func TestOpenAIWSRawPayloadHasToolCallOutput(t *testing.T) {
t.Parallel()
for _, typ := range []string{
"function_call_output",
"tool_search_output",
"custom_tool_call_output",
"mcp_tool_call_output",
} {
typ := typ
t.Run(typ, func(t *testing.T) {
t.Parallel()
payload := []byte(`{"input":[{"type":"` + typ + `","call_id":"call_1","output":"ok"}]}`)
require.True(t, openAIWSRawPayloadHasToolCallOutput(payload))
})
}
t.Run("object_input", func(t *testing.T) {
t.Parallel()
payload := []byte(`{"input":{"type":"tool_search_output","call_id":"call_1","output":"ok"}}`)
require.True(t, openAIWSRawPayloadHasToolCallOutput(payload))
})
t.Run("non_tool_output", func(t *testing.T) {
t.Parallel()
payload := []byte(`{"input":[{"type":"input_text","text":"hello"}]}`)
require.False(t, openAIWSRawPayloadHasToolCallOutput(payload))
})
}
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
t.Parallel()