Merge pull request #2747 from siyuan-123/fix/ws-tool-output-continuation
修复 WS 协议下工具输出续链识别问题
This commit is contained in:
commit
2fb9fb2f71
@ -20,6 +20,32 @@ type FunctionCallOutputValidation struct {
|
||||
HasItemReferenceForAllCallIDs bool
|
||||
}
|
||||
|
||||
func isCodexToolCallContextItemType(typ string) bool {
|
||||
switch strings.TrimSpace(typ) {
|
||||
case "tool_call",
|
||||
"function_call",
|
||||
"local_shell_call",
|
||||
"tool_search_call",
|
||||
"custom_tool_call",
|
||||
"mcp_tool_call":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isCodexToolCallOutputItemType(typ string) bool {
|
||||
switch strings.TrimSpace(typ) {
|
||||
case "function_call_output",
|
||||
"tool_search_output",
|
||||
"custom_tool_call_output",
|
||||
"mcp_tool_call_output":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
@ -53,7 +79,9 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。
|
||||
// AnalyzeToolContinuationSignals 单次遍历 input,提取工具输出/工具调用上下文/item_reference 相关信号。
|
||||
// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖 Codex 的所有工具输出
|
||||
// (function_call_output/tool_search_output/custom_tool_call_output/mcp_tool_call_output)。
|
||||
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
|
||||
signals := ToolContinuationSignals{}
|
||||
if reqBody == nil {
|
||||
@ -73,13 +101,13 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "tool_call", "function_call":
|
||||
switch {
|
||||
case isCodexToolCallContextItemType(itemType):
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
signals.HasToolCallContext = true
|
||||
}
|
||||
case "function_call_output":
|
||||
case isCodexToolCallOutputItemType(itemType):
|
||||
signals.HasFunctionCallOutput = true
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
@ -91,7 +119,7 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
|
||||
callIDs = make(map[string]struct{})
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
case itemType == "item_reference":
|
||||
signals.HasItemReference = true
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
@ -123,9 +151,10 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign
|
||||
}
|
||||
|
||||
// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果:
|
||||
// 1) 无 function_call_output 直接返回
|
||||
// 2) 若已存在 tool_call/function_call 上下文则提前返回
|
||||
// 1) 无工具输出直接返回
|
||||
// 2) 若已存在工具调用上下文则提前返回
|
||||
// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合
|
||||
// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖所有 Codex 工具输出。
|
||||
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
|
||||
result := FunctionCallOutputValidation{}
|
||||
if reqBody == nil {
|
||||
@ -142,10 +171,10 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
switch {
|
||||
case isCodexToolCallOutputItemType(itemType):
|
||||
result.HasFunctionCallOutput = true
|
||||
case "tool_call", "function_call":
|
||||
case isCodexToolCallContextItemType(itemType):
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
result.HasToolCallContext = true
|
||||
@ -168,8 +197,8 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
switch {
|
||||
case isCodexToolCallOutputItemType(itemType):
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
@ -177,7 +206,7 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
|
||||
continue
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
case itemType == "item_reference":
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
@ -201,24 +230,25 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu
|
||||
return result
|
||||
}
|
||||
|
||||
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||
// HasFunctionCallOutput 判断 input 是否包含任意 Codex 工具输出,用于触发续链校验。
|
||||
// 名称保留 function_call_output 是为了兼容既有调用点。
|
||||
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
|
||||
}
|
||||
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的工具调用上下文,
|
||||
// 用于判断工具输出是否具备可关联的上下文。
|
||||
func HasToolCallContext(reqBody map[string]any) bool {
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
|
||||
}
|
||||
|
||||
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||
// FunctionCallOutputCallIDs 提取 input 中工具输出的 call_id 集合。
|
||||
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
|
||||
}
|
||||
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的工具输出。
|
||||
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
|
||||
}
|
||||
|
||||
@ -38,41 +38,57 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutput(t *testing.T) {
|
||||
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||
// 所有 Codex 工具输出都应视为续链输出,避免 WS 续链时丢失 previous_response_id。
|
||||
require.False(t, HasFunctionCallOutput(nil))
|
||||
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
for _, typ := range []string{
|
||||
"function_call_output",
|
||||
"tool_search_output",
|
||||
"custom_tool_call_output",
|
||||
"mcp_tool_call_output",
|
||||
} {
|
||||
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": []any{map[string]any{"type": typ}},
|
||||
}), typ)
|
||||
}
|
||||
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": "text",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHasToolCallContext(t *testing.T) {
|
||||
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||
// 工具调用上下文必须包含 call_id,才能作为可关联上下文。
|
||||
require.False(t, HasToolCallContext(nil))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||
}))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||
}))
|
||||
for _, typ := range []string{
|
||||
"tool_call",
|
||||
"function_call",
|
||||
"local_shell_call",
|
||||
"tool_search_call",
|
||||
"custom_tool_call",
|
||||
"mcp_tool_call",
|
||||
} {
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": typ, "call_id": "call_1"}},
|
||||
}), typ)
|
||||
}
|
||||
require.False(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call"}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||
// 仅提取非空 call_id,去重后返回。
|
||||
// 仅提取工具输出的非空 call_id,去重后返回。
|
||||
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
map[string]any{"type": "tool_search_output", "call_id": "call_search"},
|
||||
map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom"},
|
||||
map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp"},
|
||||
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
})
|
||||
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||
require.ElementsMatch(t, []string{"call_1", "call_search", "call_custom", "call_mcp"}, callIDs)
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||
@ -80,8 +96,11 @@ func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_search_output"}},
|
||||
}))
|
||||
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||
"input": []any{map[string]any{"type": "tool_search_output", "call_id": "call_1"}},
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@ -1548,13 +1548,35 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
|
||||
|
||||
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
|
||||
for _, item := range items {
|
||||
if gjson.GetBytes(item, "type").String() == "function_call_output" {
|
||||
if isCodexToolCallOutputItemType(gjson.GetBytes(item, "type").String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func openAIWSRawPayloadHasToolCallOutput(payload []byte) bool {
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
input := gjson.GetBytes(payload, "input")
|
||||
if !input.Exists() {
|
||||
return false
|
||||
}
|
||||
if input.IsArray() {
|
||||
for _, item := range input.Array() {
|
||||
if isCodexToolCallOutputItemType(item.Get("type").String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if input.Type == gjson.JSON {
|
||||
return isCodexToolCallOutputItemType(input.Get("type").String())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildOpenAIWSReplayInputSequence(
|
||||
previousFullInput []json.RawMessage,
|
||||
previousFullInputExists bool,
|
||||
@ -2855,7 +2877,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
|
||||
turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
|
||||
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
|
||||
turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
|
||||
turnHasFunctionCallOutput := openAIWSRawPayloadHasToolCallOutput(payload)
|
||||
eventCount := 0
|
||||
tokenEventCount := 0
|
||||
terminalEventCount := 0
|
||||
@ -3131,7 +3153,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
currentTurnReplayInputExists := false
|
||||
skipBeforeTurn := false
|
||||
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
|
||||
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
|
||||
if openAIWSRawPayloadHasToolCallOutput(payload) {
|
||||
return true
|
||||
}
|
||||
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
|
||||
@ -3256,7 +3278,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
|
||||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||||
toolSignals := ToolContinuationSignals{
|
||||
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
|
||||
HasFunctionCallOutput: openAIWSRawPayloadHasToolCallOutput(currentPayload),
|
||||
}
|
||||
if toolSignals.HasFunctionCallOutput {
|
||||
var currentReqBody map[string]any
|
||||
|
||||
@ -1223,6 +1223,141 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
|
||||
require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledToolSearchOutputAutoAttachesPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 145,
|
||||
Name: "openai-ingress-tool-search-output-auto-prev",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_tool_search_prev_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"tool_search_output","call_id":"call_search_1","output":"ok"}]}`)
|
||||
secondTurn := readMessage()
|
||||
require.Equal(t, "resp_tool_search_prev_2", gjson.GetBytes(secondTurn, "response.id").String())
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
secondWrite := requestToJSONString(captureConn.writes[1])
|
||||
require.Equal(t, "resp_tool_search_prev_1", gjson.Get(secondWrite, "previous_response_id").String(), "tool_search_output 缺失 previous_response_id 时应回填上一轮响应 ID")
|
||||
require.Equal(t, "tool_search_output", gjson.Get(secondWrite, "input.0.type").String())
|
||||
require.Equal(t, "call_search_1", gjson.Get(secondWrite, "input.0.call_id").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -696,6 +696,36 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIWSRawPayloadHasToolCallOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, typ := range []string{
|
||||
"function_call_output",
|
||||
"tool_search_output",
|
||||
"custom_tool_call_output",
|
||||
"mcp_tool_call_output",
|
||||
} {
|
||||
typ := typ
|
||||
t.Run(typ, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
payload := []byte(`{"input":[{"type":"` + typ + `","call_id":"call_1","output":"ok"}]}`)
|
||||
require.True(t, openAIWSRawPayloadHasToolCallOutput(payload))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("object_input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
payload := []byte(`{"input":{"type":"tool_search_output","call_id":"call_1","output":"ok"}}`)
|
||||
require.True(t, openAIWSRawPayloadHasToolCallOutput(payload))
|
||||
})
|
||||
|
||||
t.Run("non_tool_output", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
payload := []byte(`{"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
require.False(t, openAIWSRawPayloadHasToolCallOutput(payload))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user