diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 428231ee..87cfc4f8 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -1138,6 +1138,99 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout( require.False(t, result.clientDisconnect) } +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamKeepaliveInterval: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + time.Sleep(1200 * time.Millisecond) + _, _ = pw.Write([]byte(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":3}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":2}}`, + "", + "data: [DONE]", + "", + }, "\n"))) + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 8}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), "event: ping\ndata: {\"type\": \"ping\"}\n\n") + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamKeepaliveInterval: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}` + "\n")) + time.Sleep(1200 * time.Millisecond) + _, _ = pw.Write([]byte("\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 9}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + body := rec.Body.String() + require.NotContains(t, body, `data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}`+"\n"+"event: ping") + require.NotContains(t, body, "event: ping") + require.Contains(t, body, "data: [DONE]") +} + func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6151d78e..b8cbf715 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5357,6 +5357,22 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( intervalCh = intervalTicker.C } + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + inPartialEvent := false + for { select { case ev, ok := <-events: @@ -5422,6 +5438,10 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( } else if line == "" { // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 flusher.Flush() + lastDataAt = time.Now() + inPartialEvent = false + } else { + inPartialEvent = true } } @@ -5438,6 +5458,21 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( s.rateLimitService.HandleStreamTimeout(ctx, account, model) } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected || inPartialEvent { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + if _, err := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during keepalive ping, continue draining upstream for usage: account=%d", account.ID) + continue + } + flusher.Flush() + lastDataAt = time.Now() } } }