Merge pull request #2799 from siyuan-123/fix/ws-rate-limit-failover

修复 OpenAI WS 限额时不自动切换账号
This commit is contained in:
Wesley Liddick 2026-05-27 15:14:28 +08:00 committed by GitHub
commit 2387cf9934
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 687 additions and 198 deletions

View File

@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
releaseTurnSlots := func() { releaseAccountSlot := func() {
if currentAccountRelease != nil { if currentAccountRelease != nil {
currentAccountRelease() currentAccountRelease()
currentAccountRelease = nil currentAccountRelease = nil
} }
}
releaseTurnSlots := func() {
releaseAccountSlot()
if currentUserRelease != nil { if currentUserRelease != nil {
currentUserRelease() currentUserRelease()
currentUserRelease = nil currentUserRelease = nil
@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
return return
} }
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
ensureUserSlotHeld := func() bool {
if currentUserRelease != nil {
return true
}
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
return false
}
if !userAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
return false
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
return true
}
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
@ -1246,23 +1266,41 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
firstMessage, firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
) )
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
for {
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx, ctx,
apiKey.GroupID, apiKey.GroupID,
previousResponseID, previousResponseID,
sessionHash, sessionHash,
reqModel, reqModel,
nil, failedAccountIDs,
service.OpenAIUpstreamTransportResponsesWebsocketV2, service.OpenAIUpstreamTransportResponsesWebsocketV2,
false, false,
) )
if err != nil { if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) reqLog.Warn("openai.websocket_account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return return
} }
if selection == nil || selection.Account == nil { if selection == nil || selection.Account == nil {
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return return
} }
@ -1418,6 +1456,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
} }
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
releaseAccountSlot()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
h.gatewayService.RecordOpenAIAccountSwitch()
reqLog.Warn("openai.websocket_upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if !ensureUserSlotHeld() {
return
}
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err) closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed", reqLog.Warn("openai.websocket_proxy_failed",
@ -1435,6 +1501,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
return return
} }
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
return
}
} }
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow() _ = conn.CloseNow()
} }
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
if failoverErr == nil {
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
switch failoverErr.StatusCode {
case http.StatusTooManyRequests:
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later")
case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable")
case http.StatusUnauthorized, http.StatusForbidden:
closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed")
default:
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
}
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil { if conn == nil || decision == nil {
return return

View File

@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
return &account, nil return &account, nil
} }
type openAIWSFailoverHandlerAccountRepoStub struct {
service.AccountRepository
accounts []service.Account
rateLimitedIDs []int64
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
out := make([]service.Account, 0, len(s.accounts))
for _, account := range s.accounts {
if account.Platform == platform && account.IsSchedulable() {
out = append(out, account)
}
}
return out, nil
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return s.ListSchedulableByPlatform(ctx, platform)
}
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return s.ListSchedulableByPlatform(ctx, platform)
}
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
for _, account := range s.accounts {
if account.ID == id {
acc := account
return &acc, nil
}
}
return nil, nil
}
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
s.rateLimitedIDs = append(s.rateLimitedIDs, id)
for i := range s.accounts {
if s.accounts[i].ID == id {
reset := resetAt
s.accounts[i].RateLimitResetAt = &reset
break
}
}
return nil
}
type openAIWSUsageHandlerUsageLogRepoStub struct { type openAIWSUsageHandlerUsageLogRepoStub struct {
service.UsageLogRepository service.UsageLogRepository
created chan *service.UsageLog created chan *service.UsageLog
@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
return out, nil return out, nil
} }
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
firstHitCh := make(chan []byte, 1)
secondHitCh := make(chan []byte, 1)
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
firstHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
cancelWrite()
}))
defer firstUpstream.Close()
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
secondHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
cancelWrite()
_ = conn.Close(coderws.StatusNormalClosure, "done")
}))
defer secondUpstream.Close()
groupID := int64(4202)
accounts := []service.Account{
{
ID: 9902,
Name: "openai-ws-rate-limited",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "sk-first",
"base_url": firstUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
{
ID: 9903,
Name: "openai-ws-healthy",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 2,
Credentials: map[string]any{
"api_key": "sk-second",
"base_url": secondUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
}
cfg := &config.Config{}
cfg.RunMode = config.RunModeSimple
cfg.Default.RateMultiplier = 1
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
cfg.Gateway.MaxAccountSwitches = 3
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
service.NewBillingService(cfg, nil),
rateLimitSvc,
billingCacheSvc,
nil,
&service.DeferredService{},
nil,
nil,
nil,
nil,
nil,
nil,
)
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
h := &OpenAIGatewayHandler{
gatewayService: gatewaySvc,
billingCacheService: billingCacheSvc,
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
maxAccountSwitches: 3,
}
apiKey := &service.APIKey{
ID: 1802,
GroupID: &groupID,
User: &service.User{ID: 1702, Status: service.StatusActive},
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
handlerServer := httptest.NewServer(router)
defer handlerServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(
dialCtx,
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
_, event, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
select {
case <-firstHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第一个上游收到首帧超时")
}
select {
case <-secondHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第二个上游收到重放首帧超时")
}
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
}
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -2782,6 +2782,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
var dialErr *openAIWSDialError var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(dialErr.ResponseHeaders),
}
} }
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError( return nil, NewOpenAIWSClientCloseError(
@ -2977,6 +2981,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
false, false,
) )
} }
if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
lease.MarkBroken()
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: append([]byte(nil), upstreamMessage...),
ResponseHeaders: cloneHeader(lease.HandshakeHeaders()),
}
}
} }
isTokenEvent := isOpenAIWSTokenEvent(eventType) isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent { if isTokenEvent {

View File

@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
select { select {
case serverErr := <-serverErrCh: case serverErr := <-serverErrCh:
require.Error(t, serverErr) require.Error(t, serverErr)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, serverErr, &failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Len(t, repo.rateLimitCalls, 1) require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):

View File

@ -59,8 +59,12 @@ type RelayOptions struct {
IdleTimeout time.Duration IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType FirstMessageType coderws.MessageType
FirstMessageSent bool
StartClientAfterFirstDownstream bool
OnUsageParseFailure func(eventType string, usageRaw string) OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult) OnTurnComplete func(turn RelayTurnResult)
BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error
ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error)
OnTrace func(event RelayTraceEvent) OnTrace func(event RelayTraceEvent)
Now func() time.Time Now func() time.Time
} }
@ -170,6 +174,14 @@ func Relay(
MessageType: relayMessageTypeString(firstMessageType), MessageType: relayMessageTypeString(firstMessageType),
}) })
if options.FirstMessageSent {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_skipped",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
} else {
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt) result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
@ -181,18 +193,28 @@ func Relay(
}) })
return result, &RelayExit{Stage: "write_upstream", Err: err} return result, &RelayExit{Stage: "write_upstream", Err: err}
} }
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok", Stage: "write_first_message_ok",
Direction: "client_to_upstream", Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType), MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage), PayloadBytes: len(firstClientMessage),
}) })
}
clientToUpstreamFrames.Add(1)
markActivity() markActivity()
exitCh := make(chan relayExitSignal, 3) exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{} dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) clientReaderStarted := atomic.Bool{}
startClientReader := func() {
if !clientReaderStarted.CompareAndSwap(false, true) {
return
}
go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
}
if !options.StartClientAfterFirstDownstream {
startClientReader()
}
go runUpstreamToClient( go runUpstreamToClient(
relayCtx, relayCtx,
upstreamConn, upstreamConn,
@ -202,6 +224,12 @@ func Relay(
state, state,
options.OnUsageParseFailure, options.OnUsageParseFailure,
options.OnTurnComplete, options.OnTurnComplete,
options.BeforeWriteClient,
func() {
if options.StartClientAfterFirstDownstream {
startClientReader()
}
},
&dropDownstreamWrites, &dropDownstreamWrites,
upstreamToClientFrames, upstreamToClientFrames,
droppedDownstreamFrames, droppedDownstreamFrames,
@ -230,8 +258,10 @@ func Relay(
} else { } else {
relayCancel() relayCancel()
_ = upstreamConn.Close() _ = upstreamConn.Close()
if clientReaderStarted.Load() {
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
} }
}
if hasSecondExit { if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
@ -250,6 +280,14 @@ func Relay(
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load() result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
if firstExit.stage == "read_client" && firstExit.graceful { if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected" stage := "client_disconnected"
exitErr := firstExit.err exitErr := firstExit.err
@ -310,6 +348,14 @@ func Relay(
WroteDownstream: combinedWroteDownstream, WroteDownstream: combinedWroteDownstream,
} }
} }
if options.FirstMessageSent {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete", Stage: "relay_complete",
Graceful: true, Graceful: true,
@ -322,14 +368,20 @@ func Relay(
func runClientToUpstream( func runClientToUpstream(
ctx context.Context, ctx context.Context,
clientConn FrameConn, clientConn FrameConn,
readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error),
writeUpstream func(msgType coderws.MessageType, payload []byte) error, writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(), markActivity func(),
forwardedFrames *atomic.Int64, forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent), onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal, exitCh chan<- relayExitSignal,
) { ) {
if readClientFrame == nil {
readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) {
return conn.ReadFrame(ctx)
}
}
for { for {
msgType, payload, err := clientConn.ReadFrame(ctx) msgType, payload, err := readClientFrame(ctx, clientConn)
if err != nil { if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed", Stage: "read_client_failed",
@ -368,6 +420,8 @@ func runUpstreamToClient(
state *relayState, state *relayState,
onUsageParseFailure func(eventType string, usageRaw string), onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult), onTurnComplete func(turn RelayTurnResult),
beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error,
afterWriteClient func(),
dropDownstreamWrites *atomic.Bool, dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64, forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64, droppedFrames *atomic.Int64,
@ -395,6 +449,24 @@ func runUpstreamToClient(
return return
} }
markActivity() markActivity()
if beforeWriteClient != nil {
if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "upstream_message_rejected",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{
stage: "upstream_message",
err: err,
wroteDownstream: wroteDownstream,
}
return
}
}
observedEvent := observedUpstreamEvent{} observedEvent := observedUpstreamEvent{}
switch msgType { switch msgType {
case coderws.MessageText: case coderws.MessageText:
@ -438,6 +510,9 @@ func runUpstreamToClient(
return return
} }
wroteDownstream = true wroteDownstream = true
if afterWriteClient != nil {
afterWriteClient()
}
if forwardedFrames != nil { if forwardedFrames != nil {
forwardedFrames.Add(1) forwardedFrames.Add(1)
} }

View File

@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
runClientToUpstream( runClientToUpstream(
context.Background(), context.Background(),
newPassthroughTestFrameConn(nil, true), newPassthroughTestFrameConn(nil, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil }, func(_ coderws.MessageType, _ []byte) error { return nil },
func() {}, func() {},
nil, nil,
@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{ newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true), }, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {}, func() {},
nil, nil,
@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{ newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true), }, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil }, func(_ coderws.MessageType, _ []byte) error { return nil },
func() {}, func() {},
forwarded, forwarded,
@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
nil, nil,
@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
nil, nil,
@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
dropped, dropped,

View File

@ -359,6 +359,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
statusCode, statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
) )
if statusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(handshakeHeaders),
}
}
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
} }
defer func() { defer func() {
@ -456,6 +463,34 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
cancel() cancel()
}, },
} }
upstreamFirstMessageSent := false
firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage)
cancelFirstWrite()
if firstWriteErr != nil {
return wrapOpenAIWSIngressTurnError(
"write_upstream",
fmt.Errorf("write first upstream websocket request: %w", firstWriteErr),
false,
)
}
upstreamFirstMessageSent = true
readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) {
for {
msgType, payload, readErr := conn.ReadFrame(readCtx)
if readErr != nil {
return msgType, payload, readErr
}
if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
return msgType, payload, nil
}
if writeErr := upstreamFrameConn.WriteFrame(readCtx, msgType, payload); writeErr != nil {
return msgType, payload, writeErr
}
}
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx, Ctx: ctx,
ClientConn: policyClientConn, ClientConn: policyClientConn,
@ -465,6 +500,9 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
WriteTimeout: s.openAIWSWriteTimeout(), WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(), IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText, FirstMessageType: coderws.MessageText,
FirstMessageSent: upstreamFirstMessageSent,
StartClientAfterFirstDownstream: true,
ReadClientFrame: readNextClientFrame,
OnUsageParseFailure: func(eventType string, usageRaw string) { OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough( logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s", "usage_parse_failed event_type=%s usage_raw=%s",
@ -507,6 +545,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
hooks.AfterTurn(turnNo, turnResult, nil) hooks.AfterTurn(turnNo, turnResult, nil)
} }
}, },
BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error {
if msgType != coderws.MessageText || wroteDownstream {
return nil
}
if eventType, _, _ := parseOpenAIWSEventEnvelope(payload); eventType != "error" {
return nil
}
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(payload)
if !isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
return nil
}
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, errCodeRaw, errTypeRaw, errMsgRaw)
logOpenAIWSV2Passthrough(
"relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s",
account.ID,
truncateOpenAIWSLogValue(errCodeRaw, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(errMsgRaw, openAIWSLogValueMaxLen),
)
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: append([]byte(nil), payload...),
ResponseHeaders: cloneHeader(handshakeHeaders),
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) { OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough( logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",