fix: add keepalive for Anthropic passthrough streams
This commit is contained in:
parent
f5bd25bea0
commit
164e2f610c
@ -1138,6 +1138,99 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(
|
||||
require.False(t, result.clientDisconnect)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 1,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: pr,
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
_, _ = pw.Write([]byte(strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"usage":{"input_tokens":3}}}`,
|
||||
"",
|
||||
`data: {"type":"message_delta","usage":{"output_tokens":2}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")))
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 8}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, rec.Body.String(), "event: ping\ndata: {\"type\": \"ping\"}\n\n")
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamKeepaliveInterval: 1,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: pr,
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}` + "\n"))
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
_, _ = pw.Write([]byte("\n"))
|
||||
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 9}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
body := rec.Body.String()
|
||||
require.NotContains(t, body, `data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}`+"\n"+"event: ping")
|
||||
require.NotContains(t, body, "event: ping")
|
||||
require.Contains(t, body, "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -5357,6 +5357,22 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
inPartialEvent := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
@ -5422,6 +5438,10 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
} else if line == "" {
|
||||
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
|
||||
flusher.Flush()
|
||||
lastDataAt = time.Now()
|
||||
inPartialEvent = false
|
||||
} else {
|
||||
inPartialEvent = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -5438,6 +5458,21 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if clientDisconnected || inPartialEvent {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during keepalive ping, continue draining upstream for usage: account=%d", account.ID)
|
||||
continue
|
||||
}
|
||||
flusher.Flush()
|
||||
lastDataAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user