fix: add keepalive for Anthropic passthrough streams

This commit is contained in:
lyen1688 2026-05-18 18:41:25 +08:00
parent f5bd25bea0
commit 164e2f610c
2 changed files with 128 additions and 0 deletions

View File

@ -1138,6 +1138,99 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(
require.False(t, result.clientDisconnect)
}
// TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle
// verifies that the passthrough streaming handler writes a ping event to the
// client when the upstream stays silent for longer than the configured
// keepalive interval, and that the upstream payload delivered afterwards is
// still forwarded.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	gatewayCfg := config.GatewayConfig{
		StreamKeepaliveInterval: 1,
		MaxLineSize:             defaultMaxLineSize,
	}
	svc := &GatewayService{
		cfg:              &config.Config{Gateway: gatewayCfg},
		rateLimitService: &RateLimitService{},
	}

	// The pipe plays the role of the upstream response body so the test
	// decides exactly when bytes become readable.
	upstreamR, upstreamW := io.Pipe()
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       upstreamR,
	}

	sseLines := []string{
		`data: {"type":"message_start","message":{"usage":{"input_tokens":3}}}`,
		"",
		`data: {"type":"message_delta","usage":{"output_tokens":2}}`,
		"",
		"data: [DONE]",
		"",
	}

	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		// Remain idle for longer than the keepalive interval before
		// producing any upstream data.
		time.Sleep(1200 * time.Millisecond)
		_, _ = upstreamW.Write([]byte(strings.Join(sseLines, "\n")))
		_ = upstreamW.Close()
	}()

	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(
		context.Background(),
		upstreamResp,
		ginCtx,
		&Account{ID: 8},
		time.Now(),
		"claude-3-7-sonnet-20250219",
	)

	// Closing the read side releases the writer goroutine in case the
	// handler returned before the pipe was fully consumed.
	_ = upstreamR.Close()
	<-writerDone

	require.NoError(t, err)
	require.NotNil(t, result)

	body := recorder.Body.String()
	require.Contains(t, body, "event: ping\ndata: {\"type\": \"ping\"}\n\n")
	require.Contains(t, body, "data: [DONE]")
}
// TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent
// verifies that no ping is written while an SSE event is only partially
// delivered: the upstream sends a data line, stalls for longer than the
// keepalive interval, and only then sends the blank line that ends the event.
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent(t *testing.T) {
	// The data line of the event that is left unterminated during the stall.
	const startLine = `data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}`

	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	gatewayCfg := config.GatewayConfig{
		StreamKeepaliveInterval: 1,
		MaxLineSize:             defaultMaxLineSize,
	}
	svc := &GatewayService{
		cfg:              &config.Config{Gateway: gatewayCfg},
		rateLimitService: &RateLimitService{},
	}

	// The pipe plays the role of the upstream response body so the test
	// decides exactly when bytes become readable.
	upstreamR, upstreamW := io.Pipe()
	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       upstreamR,
	}

	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		// Deliver the data line without its terminating blank line...
		_, _ = upstreamW.Write([]byte(startLine + "\n"))
		// ...stall past the keepalive interval while the event is open...
		time.Sleep(1200 * time.Millisecond)
		// ...then close the event and finish the stream.
		_, _ = upstreamW.Write([]byte("\n"))
		_, _ = upstreamW.Write([]byte("data: [DONE]\n\n"))
		_ = upstreamW.Close()
	}()

	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(
		context.Background(),
		upstreamResp,
		ginCtx,
		&Account{ID: 9},
		time.Now(),
		"claude-3-7-sonnet-20250219",
	)

	// Closing the read side releases the writer goroutine in case the
	// handler returned before the pipe was fully consumed.
	_ = upstreamR.Close()
	<-writerDone

	require.NoError(t, err)
	require.NotNil(t, result)

	body := recorder.Body.String()
	// The first check pinpoints a ping injected directly after the open data
	// line; the second rejects a ping anywhere in the output.
	require.NotContains(t, body, startLine+"\n"+"event: ping")
	require.NotContains(t, body, "event: ping")
	require.Contains(t, body, "data: [DONE]")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@ -5357,6 +5357,22 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
inPartialEvent := false
for {
select {
case ev, ok := <-events:
@ -5422,6 +5438,10 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
} else if line == "" {
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
flusher.Flush()
lastDataAt = time.Now()
inPartialEvent = false
} else {
inPartialEvent = true
}
}
@ -5438,6 +5458,21 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || inPartialEvent {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during keepalive ping, continue draining upstream for usage: account=%d", account.ID)
continue
}
flusher.Flush()
lastDataAt = time.Now()
}
}
}