fix(openai): avoid inferred WS continuation on explicit tool replay
This commit is contained in:
parent
4d676dddd1
commit
28dc34b6a3
@ -1366,16 +1366,25 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
|
||||
func shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled bool,
|
||||
turn int,
|
||||
hasFunctionCallOutput bool,
|
||||
signals ToolContinuationSignals,
|
||||
currentPreviousResponseID string,
|
||||
expectedPreviousResponseID string,
|
||||
) bool {
|
||||
if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
|
||||
if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(currentPreviousResponseID) != "" {
|
||||
return false
|
||||
}
|
||||
if signals.HasFunctionCallOutputMissingCallID {
|
||||
return false
|
||||
}
|
||||
// If the client already sent tool-call context or item_reference anchors,
|
||||
// treat this as a full replay / self-contained continuation payload rather
|
||||
// than downgrading it into an inferred delta continuation.
|
||||
if signals.HasToolCallContext || signals.HasItemReferenceForAllCallIDs {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(expectedPreviousResponseID) != ""
|
||||
}
|
||||
|
||||
@ -3179,13 +3188,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
skipBeforeTurn = false
|
||||
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
|
||||
expectedPrev := strings.TrimSpace(lastTurnResponseID)
|
||||
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||||
toolSignals := ToolContinuationSignals{
|
||||
HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
|
||||
}
|
||||
if toolSignals.HasFunctionCallOutput {
|
||||
var currentReqBody map[string]any
|
||||
if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil {
|
||||
toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
|
||||
}
|
||||
}
|
||||
hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
|
||||
// store=false + function_call_output 场景必须有续链锚点。
|
||||
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
|
||||
if shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
storeDisabled,
|
||||
turn,
|
||||
hasFunctionCallOutput,
|
||||
toolSignals,
|
||||
currentPreviousResponseID,
|
||||
expectedPrev,
|
||||
) {
|
||||
|
||||
@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{captureConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Name: "openai-ingress-tool-context",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
|
||||
secondTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenItemReferencesPresent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{captureConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Name: "openai-ingress-item-reference",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
|
||||
secondTurn := readMessage()
|
||||
require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
|
||||
|
||||
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.NoError(t, serverErr)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, captureDialer.DialCount())
|
||||
require.Len(t, captureConn.writes, 2)
|
||||
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 item_reference 锚点时不应自动补齐 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||
|
||||
@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
name string
|
||||
storeDisabled bool
|
||||
turn int
|
||||
hasFunctionCallOutput bool
|
||||
signals ToolContinuationSignals
|
||||
currentPreviousResponse string
|
||||
expectedPrevious string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "infer_when_all_conditions_match",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: true,
|
||||
name: "infer_when_all_conditions_match",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_store_enabled",
|
||||
storeDisabled: false,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_when_store_enabled",
|
||||
storeDisabled: false,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_on_first_turn",
|
||||
storeDisabled: true,
|
||||
turn: 1,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_on_first_turn",
|
||||
storeDisabled: true,
|
||||
turn: 1,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_without_function_call_output",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: false,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
name: "skip_without_function_call_output",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{},
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_request_already_has_previous_response_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
name: "skip_when_request_already_has_previous_response_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
currentPreviousResponse: "resp_client",
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_last_turn_response_id_missing",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "",
|
||||
want: false,
|
||||
name: "skip_when_last_turn_response_id_missing",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "trim_whitespace_before_judgement",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: " resp_2 ",
|
||||
want: true,
|
||||
name: "trim_whitespace_before_judgement",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true},
|
||||
expectedPrevious: " resp_2 ",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_tool_call_context_already_present",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_item_reference_already_covers_all_call_ids",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_function_call_output_missing_call_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
|
||||
expectedPrevious: "resp_2",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
tt.storeDisabled,
|
||||
tt.turn,
|
||||
tt.hasFunctionCallOutput,
|
||||
tt.signals,
|
||||
tt.currentPreviousResponse,
|
||||
tt.expectedPrevious,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user