fix: add keepalive for Anthropic passthrough streams
This commit is contained in:
parent
f5bd25bea0
commit
164e2f610c
@ -1138,6 +1138,99 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(
|
|||||||
require.False(t, result.clientDisconnect)
|
require.False(t, result.clientDisconnect)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamKeepaliveInterval: 1,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: pr,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
time.Sleep(1200 * time.Millisecond)
|
||||||
|
_, _ = pw.Write([]byte(strings.Join([]string{
|
||||||
|
`data: {"type":"message_start","message":{"usage":{"input_tokens":3}}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"message_delta","usage":{"output_tokens":2}}`,
|
||||||
|
"",
|
||||||
|
"data: [DONE]",
|
||||||
|
"",
|
||||||
|
}, "\n")))
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 8}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
_ = pr.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Contains(t, rec.Body.String(), "event: ping\ndata: {\"type\": \"ping\"}\n\n")
|
||||||
|
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamKeepaliveInterval: 1,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: pr,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}` + "\n"))
|
||||||
|
time.Sleep(1200 * time.Millisecond)
|
||||||
|
_, _ = pw.Write([]byte("\n"))
|
||||||
|
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 9}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
_ = pr.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
body := rec.Body.String()
|
||||||
|
require.NotContains(t, body, `data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}`+"\n"+"event: ping")
|
||||||
|
require.NotContains(t, body, "event: ping")
|
||||||
|
require.Contains(t, body, "data: [DONE]")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|||||||
@ -5357,6 +5357,22 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
intervalCh = intervalTicker.C
|
intervalCh = intervalTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keepaliveInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||||
|
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||||
|
}
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
if keepaliveInterval > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
var keepaliveCh <-chan time.Time
|
||||||
|
if keepaliveTicker != nil {
|
||||||
|
keepaliveCh = keepaliveTicker.C
|
||||||
|
}
|
||||||
|
lastDataAt := time.Now()
|
||||||
|
inPartialEvent := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev, ok := <-events:
|
case ev, ok := <-events:
|
||||||
@ -5422,6 +5438,10 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
} else if line == "" {
|
} else if line == "" {
|
||||||
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
|
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
lastDataAt = time.Now()
|
||||||
|
inPartialEvent = false
|
||||||
|
} else {
|
||||||
|
inPartialEvent = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5438,6 +5458,21 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||||
}
|
}
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected || inPartialEvent {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during keepalive ping, continue draining upstream for usage: account=%d", account.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
lastDataAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user