diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go index af8ee195..59b634e9 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -82,6 +82,7 @@ type relayState struct { terminalEventType string firstTokenMs *int turnTimingByID map[string]*relayTurnTiming + activeTurn *relayTurnTiming } type relayExitSignal struct { @@ -550,6 +551,12 @@ func observeUpstreamMessage( if ms >= 0 { state.firstTokenMs = &ms } + if state.activeTurn != nil && state.activeTurn.firstTokenMs == nil { + tms := int(now.Sub(state.activeTurn.startAt).Milliseconds()) + if tms >= 0 { + state.activeTurn.firstTokenMs = &tms + } + } } parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure) observed := observedUpstreamEvent{ @@ -622,6 +629,7 @@ func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now if !ok || timing == nil || timing.startAt.IsZero() { timing = &relayTurnTiming{startAt: now} state.turnTimingByID[responseID] = timing + state.activeTurn = timing return timing } return timing @@ -636,6 +644,9 @@ func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayT return relayTurnTiming{}, false } delete(state.turnTimingByID, responseID) + if state.activeTurn == timing { + state.activeTurn = nil + } return *timing, true } diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go index ff9b7311..cdd41a05 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -750,3 +750,67 @@ func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageT func (c *errorOnWriteFrameConn) Close() error { return nil } + +func TestRelay_OnTurnComplete_RealOpenAIStream_FirstTokenMs(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.created","response":{"id":"resp_real"}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"He"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"llo"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_real","usage":{"input_tokens":2,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 10 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_real", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + + require.NotNil(t, turn.FirstTokenMs, "per-turn FirstTokenMs must be captured for real OpenAI streams") + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + + require.Less(t, + int64(*turn.FirstTokenMs), + turn.Duration.Milliseconds(), + "per-turn FirstTokenMs (%dms) should be strictly less than Duration (%dms); "+ + "equality indicates the bug where first_token is mistakenly stamped on the terminal event", + *turn.FirstTokenMs, turn.Duration.Milliseconds(), + ) + + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, *result.FirstTokenMs, 0) +}