diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index dedbce1e..023217b2 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1366,16 +1366,25 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string func shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled bool, turn int, - hasFunctionCallOutput bool, + signals ToolContinuationSignals, currentPreviousResponseID string, expectedPreviousResponseID string, ) bool { - if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput { return false } if strings.TrimSpace(currentPreviousResponseID) != "" { return false } + if signals.HasFunctionCallOutputMissingCallID { + return false + } + // If the client already sent tool-call context or item_reference anchors, + // treat this as a full replay / self-contained continuation payload rather + // than downgrading it into an inferred delta continuation. + if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs { + return false + } return strings.TrimSpace(expectedPreviousResponseID) != "" } @@ -3179,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( skipBeforeTurn = false currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) - hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + toolSignals := ToolContinuationSignals{ + HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(), + } + if toolSignals.HasFunctionCallOutput { + var currentReqBody map[string]any + if err := json.Unmarshal(currentPayload, ¤tReqBody); err == nil { + toolSignals = AnalyzeToolContinuationSignals(currentReqBody) + } + } + hasFunctionCallOutput := toolSignals.HasFunctionCallOutput // store=false + function_call_output 场景必须有续链锚点。 // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 if shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled, turn, - hasFunctionCallOutput, + toolSignals, currentPreviousResponseID, expectedPrev, ) { diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 6bf9a9ff..701f069a 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-tool-context", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 115, + Name: "openai-ingress-item-reference", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { gin.SetMode(gin.TestMode) prevPreflightPingIdle := openAIWSIngressPreflightPingIdle diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index ff35cb01..08597f0c 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { name string storeDisabled bool turn int - hasFunctionCallOutput bool + signals ToolContinuationSignals currentPreviousResponse string expectedPrevious string want bool }{ { - name: "infer_when_all_conditions_match", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: true, + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: true, }, { - name: "skip_when_store_enabled", - storeDisabled: false, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_on_first_turn", - storeDisabled: true, - turn: 1, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_without_function_call_output", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: false, - expectedPrevious: "resp_1", - want: false, + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{}, + expectedPrevious: "resp_1", + want: false, }, { name: "skip_when_request_already_has_previous_response_id", storeDisabled: true, turn: 2, - hasFunctionCallOutput: true, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, currentPreviousResponse: "resp_client", expectedPrevious: "resp_1", want: false, }, { - name: "skip_when_last_turn_response_id_missing", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "", - want: false, + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "", + want: false, }, { - name: "trim_whitespace_before_judgement", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: " resp_2 ", - want: true, + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: " resp_2 ", + want: true, + }, + { + name: "skip_when_tool_call_context_already_present", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, + expectedPrevious: "resp_2", + want: false, + }, + { + name: "skip_when_item_reference_already_covers_all_call_ids", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, + expectedPrevious: "resp_2", + want: false, + }, + { + name: "skip_when_function_call_output_missing_call_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true}, + expectedPrevious: "resp_2", + want: false, }, } @@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { got := shouldInferIngressFunctionCallOutputPreviousResponseID( tt.storeDisabled, tt.turn, - tt.hasFunctionCallOutput, + tt.signals, tt.currentPreviousResponse, tt.expectedPrevious, )