From 6381f9e37d7b740459b88eae99b20188de94349e Mon Sep 17 00:00:00 2001 From: wucm667 Date: Tue, 19 May 2026 15:47:37 +0800 Subject: [PATCH] =?UTF-8?q?fix(openai):=20=E8=AF=86=E5=88=AB=E4=B8=8A?= =?UTF-8?q?=E6=B8=B8=E9=9D=99=E9=BB=98=E6=8B=92=E7=BB=9D=E5=B9=B6=E8=A7=A6?= =?UTF-8?q?=E5=8F=91=20failover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/handler/gateway_handler.go | 5 + .../gateway_handler_chat_completions.go | 5 + .../handler/gateway_handler_responses.go | 5 + .../handler/openai_chat_completions.go | 5 + .../handler/openai_gateway_handler.go | 10 + .../openai_gateway_chat_completions.go | 100 +++++- .../openai_gateway_chat_completions_raw.go | 96 ++++-- ...penai_gateway_chat_completions_raw_test.go | 158 ++++++++++ .../internal/service/openai_silent_refusal.go | 293 ++++++++++++++++++ 9 files changed, 649 insertions(+), 28 deletions(-) create mode 100644 backend/internal/service/openai_silent_refusal.go diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 723fddb4..0c88ebb4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1325,6 +1325,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody + if service.IsOpenAISilentRefusalErrorBody(responseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted) + return + } // 先检查透传规则 if h.errorPassthroughService != nil && len(responseBody) > 0 { diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 60b14c4b..7d2c2b1d 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -335,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv if lastErr != nil && lastErr.StatusCode > 0 { statusCode = lastErr.StatusCode } + if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage()) + return + } h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index b8a2af8e..03246f8b 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -310,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr if lastErr != nil && lastErr.StatusCode > 0 { statusCode = lastErr.StatusCode } + if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage()) + return + } h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index aff28e0f..f78c63a2 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -178,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + writerSizeBeforeForward := c.Writer.Size() result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") forwardDurationMs := time.Since(forwardStart).Milliseconds() @@ -203,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { } else { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) // Pool mode: retry on the same account if failoverErr.RetryableOnSameAccount { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index dcd737af..d9e81d4d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -332,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + writerSizeBeforeForward := c.Writer.Size() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -356,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } else { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) // 池模式:同账号重试 if failoverErr.RetryableOnSameAccount { @@ -1604,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody + if service.IsOpenAISilentRefusalErrorBody(responseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted) + return + } // 先检查透传规则 if h.errorPassthroughService != nil && len(responseBody) > 0 { diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 6cb80197..f8b23a28 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -292,7 +292,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, account, originalModel, billingModel, upstreamModel, includeUsage, startTime, len(body)) } else { result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } @@ -414,22 +414,31 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, + account *Account, originalModel string, billingModel string, upstreamModel string, includeUsage bool, startTime time.Time, + requestBodyLen int, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) } - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) state := apicompat.NewResponsesEventToChatState() state.Model = originalModel @@ -439,6 +448,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var firstTokenMs *int firstChunk := true clientDisconnected := false + clientOutputStarted := false + pendingSSE := make([]string, 0, 4) + refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen) scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -489,6 +501,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) return false } + refusalDetector.ObservePayload([]byte(payload)) // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) @@ -499,6 +512,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( chunks := apicompat.ResponsesEventToChatChunks(&event, state) if !clientDisconnected { for _, chunk := range chunks { + refusalDetector.ObserveChatChunk(chunk) sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { logger.L().Warn("openai chat_completions stream: failed to marshal chunk", @@ -507,6 +521,27 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) continue } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingSSE = append(pendingSSE, sse) + continue + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected while flushing pending chunks", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + if clientDisconnected { + break + } + } if _, err := fmt.Fprint(c.Writer, sse); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", @@ -516,7 +551,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } } } - if len(chunks) > 0 && !clientDisconnected { + if len(chunks) > 0 && !clientDisconnected && clientOutputStarted { c.Writer.Flush() } return isTerminalEvent @@ -525,10 +560,32 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( finalizeStream := func() (*OpenAIForwardResult, error) { if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { + refusalDetector.ObserveChatChunk(chunk) sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingSSE = append(pendingSSE, sse) + continue + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during pending final flush", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + if clientDisconnected { + break + } + } if _, err := fmt.Fprint(c.Writer, sse); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected during final flush", @@ -538,14 +595,35 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } } } + if !clientDisconnected && !clientOutputStarted { + if refusalDetector.IsSilentRefusal() { + return nil, newOpenAISilentRefusalFailoverError(c, account, requestID) + } + if len(pendingSSE) > 0 { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final pending flush", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + } + } // Send [DONE] sentinel if !clientDisconnected { + writeStreamHeaders() if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected during done flush", zap.String("request_id", requestID), ) } + clientOutputStarted = !clientDisconnected } if !clientDisconnected { c.Writer.Flush() @@ -702,10 +780,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( if clientDisconnected { continue } + if refusalDetector.Enabled() && !clientOutputStarted { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } // Send SSE comment as keepalive + writeStreamHeaders() if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index 0203b94a..c585290e 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -220,7 +220,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( // 8. Forward response if clientStream { - return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + return s.streamRawChatCompletions(c, resp, account, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime, len(body)) } return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) } @@ -234,23 +234,32 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( func (s *OpenAIGatewayService) streamRawChatCompletions( c *gin.Context, resp *http.Response, + account *Account, originalModel string, billingModel string, upstreamModel string, reasoningEffort *string, serviceTier *string, startTime time.Time, + requestBodyLen int, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) } - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -262,9 +271,45 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( var usage OpenAIUsage var firstTokenMs *int clientDisconnected := false + clientOutputStarted := false + pendingLines := make([]string, 0, 8) + refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen) + + writeLine := func(line string) { + if clientDisconnected { + return + } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingLines = append(pendingLines, line) + return + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingLines { + if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + return + } + } + pendingLines = pendingLines[:0] + clientOutputStarted = true + } + if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + } + } for scanner.Scan() { line := scanner.Text() + refusalDetector.ObserveSSELine(line) if payload, ok := extractOpenAISSEDataLine(line); ok { trimmedPayload := strings.TrimSpace(payload) if trimmedPayload != "[DONE]" { @@ -279,22 +324,14 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( } } - if !clientDisconnected { - if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { - clientDisconnected = true - logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", - zap.Error(werr), - zap.String("request_id", requestID), - ) - } - } + writeLine(line) if line == "" { - if !clientDisconnected { + if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } continue } - if !clientDisconnected { + if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } } @@ -306,6 +343,27 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( zap.String("request_id", requestID), ) } + } else if !clientDisconnected && !clientOutputStarted { + if refusalDetector.IsSilentRefusal() { + return nil, newOpenAISilentRefusalFailoverError(c, account, requestID) + } + if len(pendingLines) > 0 { + writeStreamHeaders() + for _, pending := range pendingLines { + if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected during final flush", + zap.Error(werr), + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() + clientOutputStarted = true + } + } } return &OpenAIForwardResult{ diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go index 46ddbc09..91f0fd14 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go @@ -5,6 +5,7 @@ package service import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" @@ -120,6 +121,157 @@ func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDown require.Contains(t, rec.Body.String(), "data: [DONE]") } +func TestForwardAsRawChatCompletions_SilentRefusalTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_silent"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.Nil(t, result) + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr)) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.True(t, IsOpenAISilentRefusalErrorBody(failoverErr.ResponseBody)) + require.False(t, c.Writer.Written(), "silent refusal must not commit a 200 response before failover") + require.Empty(t, rec.Body.String()) +} + +func TestForwardAsRawChatCompletions_SilentRefusalToolCallsExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"lookup","arguments":""}}]}}]}`, + "", + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_tool"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"tool_calls"`) + require.Contains(t, rec.Body.String(), `"finish_reason":"tool_calls"`) +} + +func TestHandleChatStreamingResponse_SilentRefusalReasoningSummaryExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_reasoning","model":"gpt-5.5"}}`, + "", + `data: {"type":"response.reasoning_summary_text.delta","delta":"thinking only"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_reasoning","model":"gpt-5.5","status":"completed"}}`, + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_reasoning"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + } + svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()} + + result, err := svc.handleChatStreamingResponse( + resp, + c, + rawChatCompletionsTestAccount(), + "gpt-5.5", + "gpt-5.5", + "gpt-5.5", + false, + time.Now(), + openAISilentRefusalMinRequestBodyBytes, + ) + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"reasoning_content":"thinking only"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestForwardAsRawChatCompletions_SilentRefusalNormalContentExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ok"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"content":"ok"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) { gin.SetMode(gin.TestMode) @@ -303,3 +455,9 @@ func rawChatCompletionsTestAccount() *Account { }, } } + +func largeRawChatCompletionsBody() []byte { + return []byte(`{"model":"gpt-5.5","messages":[{"role":"user","content":"` + + strings.Repeat("x", openAISilentRefusalMinRequestBodyBytes) + + `"}],"stream":true}`) +} diff --git a/backend/internal/service/openai_silent_refusal.go b/backend/internal/service/openai_silent_refusal.go new file mode 100644 index 00000000..27b71b75 --- /dev/null +++ b/backend/internal/service/openai_silent_refusal.go @@ -0,0 +1,293 @@ +package service + +import ( + "bytes" + "encoding/json" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +const ( + openAISilentRefusalMinRequestBodyBytes = 64 * 1024 + openAISilentRefusalErrorCode = "openai_silent_refusal" + openAISilentRefusalUpstreamMessage = "OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage" + openAISilentRefusalClientMessage = "Upstream returned an empty completion without usage; no fallback account was available" +) + +type openAIChatSilentRefusalDetector struct { + enabled bool + sawContent bool + sawToolCall bool + sawFunctionCall bool + sawUsage bool + sawError bool + sawReasoning bool + sawFinish bool + finishReason string +} + +func newOpenAIChatSilentRefusalDetector(requestBodyLen int) *openAIChatSilentRefusalDetector { + return &openAIChatSilentRefusalDetector{ + enabled: requestBodyLen >= openAISilentRefusalMinRequestBodyBytes, + } +} + +func (d *openAIChatSilentRefusalDetector) Enabled() bool { + return d != nil && d.enabled +} + +func (d *openAIChatSilentRefusalDetector) ObserveSSELine(line string) { + if d == nil || !d.enabled { + return + } + if eventType, ok := extractOpenAISSEEventLine(line); ok { + d.observeEventType(eventType) + return + } + if payload, ok := extractOpenAISSEDataLine(line); ok { + d.ObservePayload([]byte(payload)) + } +} + +func (d *openAIChatSilentRefusalDetector) ObservePayload(payload []byte) { + if d == nil || !d.enabled { + return + } + payload = bytes.TrimSpace(payload) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + return + } + if !gjson.ValidBytes(payload) { + return + } + + eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + d.observeEventType(eventType) + + if gjson.GetBytes(payload, "error").Exists() { + d.sawError = true + } + if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() { + d.sawUsage = true + } + if usage := gjson.GetBytes(payload, "response.usage"); usage.Exists() && usage.IsObject() { + d.sawUsage = true + } + + d.observeChatChoicesPayload(payload) + d.observeResponsesPayload(payload, eventType) +} + +func (d *openAIChatSilentRefusalDetector) ObserveChatChunk(chunk apicompat.ChatCompletionsChunk) { + if d == nil || !d.enabled { + return + } + if chunk.Usage != nil { + d.sawUsage = true + } + for _, choice := range chunk.Choices { + if choice.FinishReason != nil { + d.observeFinishReason(*choice.FinishReason) + } + delta := choice.Delta + if delta.Content != nil && *delta.Content != "" { + d.sawContent = true + } + if delta.ReasoningContent != nil { + d.sawReasoning = true + } + if len(delta.ToolCalls) > 0 { + d.sawToolCall = true + } + } +} + +func (d *openAIChatSilentRefusalDetector) ShouldReleaseClientOutput() bool { + if d == nil || !d.enabled { + return true + } + if d.sawContent || d.sawToolCall || d.sawFunctionCall || d.sawUsage || d.sawError || d.sawReasoning { + return true + } + return d.sawFinish && d.finishReason != "" && d.finishReason != "stop" +} + +func (d *openAIChatSilentRefusalDetector) IsSilentRefusal() bool { + if d == nil || !d.enabled { + return false + } + return !d.sawContent && + !d.sawToolCall && + !d.sawFunctionCall && + !d.sawUsage && + !d.sawError && + !d.sawReasoning && + d.sawFinish && + d.finishReason == "stop" +} + +func (d *openAIChatSilentRefusalDetector) observeEventType(eventType string) { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return + } + if eventType == "error" || eventType == "response.failed" { + d.sawError = true + } + if strings.Contains(eventType, "reasoning") || strings.Contains(eventType, "reasoning_summary") { + d.sawReasoning = true + } +} + +func (d *openAIChatSilentRefusalDetector) observeFinishReason(reason string) { + reason = strings.TrimSpace(reason) + if reason == "" { + return + } + d.sawFinish = true + d.finishReason = reason +} + +func (d *openAIChatSilentRefusalDetector) observeChatChoicesPayload(payload []byte) { + choices := gjson.GetBytes(payload, "choices") + if !choices.Exists() || !choices.IsArray() { + return + } + for _, choice := range choices.Array() { + if finish := choice.Get("finish_reason"); finish.Exists() { + d.observeFinishReason(finish.String()) + } + delta := choice.Get("delta") + if !delta.Exists() { + continue + } + if content := delta.Get("content"); content.Exists() && content.String() != "" { + d.sawContent = true + } + if delta.Get("tool_calls").Exists() { + d.sawToolCall = true + } + if delta.Get("function_call").Exists() { + d.sawFunctionCall = true + } + if delta.Get("reasoning").Exists() || + delta.Get("reasoning_content").Exists() || + delta.Get("reasoning_summary").Exists() { + d.sawReasoning = true + } + } +} + +func (d *openAIChatSilentRefusalDetector) observeResponsesPayload(payload []byte, eventType string) { + switch eventType { + case "response.output_text.delta": + if gjson.GetBytes(payload, "delta").String() != "" { + d.sawContent = true + } + case "response.output_item.added": + switch strings.TrimSpace(gjson.GetBytes(payload, "item.type").String()) { + case "function_call": + d.sawToolCall = true + case "reasoning": + d.sawReasoning = true + } + case "response.function_call_arguments.delta": + d.sawToolCall = true + case "response.reasoning_summary_text.delta", "response.reasoning_summary_text.done": + d.sawReasoning = true + case "response.completed", "response.done": + d.observeFinishReason("stop") + case "response.incomplete": + d.observeFinishReason("length") + case "response.failed": + d.sawError = true + } + + if output := gjson.GetBytes(payload, "response.output"); output.Exists() && output.IsArray() { + for _, item := range output.Array() { + switch strings.TrimSpace(item.Get("type").String()) { + case "function_call": + d.sawToolCall = true + case "reasoning": + d.sawReasoning = true + case "message": + d.observeResponseMessageItem(item) + } + } + } +} + +func (d *openAIChatSilentRefusalDetector) observeResponseMessageItem(item gjson.Result) { + content := item.Get("content") + if !content.Exists() || !content.IsArray() { + return + } + for _, part := range content.Array() { + if part.Get("text").String() != "" { + d.sawContent = true + return + } + } +} + +func newOpenAISilentRefusalFailoverError(c *gin.Context, account *Account, upstreamRequestID string) *UpstreamFailoverError { + accountID := int64(0) + accountName := "" + platform := PlatformOpenAI + if account != nil { + accountID = account.ID + accountName = account.Name + platform = account.Platform + } + + setOpsUpstreamError(c, http.StatusBadGateway, openAISilentRefusalUpstreamMessage, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: platform, + AccountID: accountID, + AccountName: accountName, + UpstreamStatusCode: http.StatusBadGateway, + UpstreamRequestID: upstreamRequestID, + Kind: "failover", + Message: openAISilentRefusalUpstreamMessage, + }) + + headers := http.Header{} + if strings.TrimSpace(upstreamRequestID) != "" { + headers.Set("x-request-id", strings.TrimSpace(upstreamRequestID)) + } + return &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: openAISilentRefusalErrorBody(), + ResponseHeaders: headers, + } +} + +func openAISilentRefusalErrorBody() []byte { + body, err := json.Marshal(map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "code": openAISilentRefusalErrorCode, + "message": openAISilentRefusalUpstreamMessage, + }, + }) + if err != nil { + return []byte(`{"error":{"type":"upstream_error","code":"openai_silent_refusal","message":"OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage"}}`) + } + return body +} + +// IsOpenAISilentRefusalErrorBody reports whether a failover body was produced +// by the OpenAI silent-refusal detector. +func IsOpenAISilentRefusalErrorBody(body []byte) bool { + return strings.TrimSpace(gjson.GetBytes(body, "error.code").String()) == openAISilentRefusalErrorCode +} + +// OpenAISilentRefusalClientMessage returns the exhausted-failover client message +// for OpenAI silent refusals. +func OpenAISilentRefusalClientMessage() string { + return openAISilentRefusalClientMessage +}