Merge pull request #2799 from siyuan-123/fix/ws-rate-limit-failover

修复 OpenAI WS 限额时不自动切换账号
This commit is contained in:
Wesley Liddick 2026-05-27 15:14:28 +08:00 committed by GitHub
commit 2387cf9934
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 687 additions and 198 deletions

View File

@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
releaseTurnSlots := func() { releaseAccountSlot := func() {
if currentAccountRelease != nil { if currentAccountRelease != nil {
currentAccountRelease() currentAccountRelease()
currentAccountRelease = nil currentAccountRelease = nil
} }
}
releaseTurnSlots := func() {
releaseAccountSlot()
if currentUserRelease != nil { if currentUserRelease != nil {
currentUserRelease() currentUserRelease()
currentUserRelease = nil currentUserRelease = nil
@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
return return
} }
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
ensureUserSlotHeld := func() bool {
if currentUserRelease != nil {
return true
}
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
return false
}
if !userAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
return false
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
return true
}
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
firstMessage, firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
) )
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( maxAccountSwitches := h.maxAccountSwitches
ctx, switchCount := 0
apiKey.GroupID, failedAccountIDs := make(map[int64]struct{})
previousResponseID, var lastFailoverErr *service.UpstreamFailoverError
sessionHash,
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
account := selection.Account for {
accountMaxConcurrency := account.Concurrency reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx, ctx,
account.ID, apiKey.GroupID,
selection.WaitPlan.MaxConcurrency, previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
) )
if err != nil { if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Warn("openai.websocket_account_select_failed",
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return return
} }
if !fastAcquired { if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return return
} }
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
token, _, err := h.gatewayService.GetAccessToken(ctx, account) account := selection.Account
if err != nil { accountMaxConcurrency := account.Concurrency
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
return }
} accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
reqLog.Debug("openai.websocket_account_selected", if selection.WaitPlan == nil {
zap.Int64("account_id", account.ID), closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return return
} }
if account.Type == service.AccountTypeOAuth { fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
return
} }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) if !fastAcquired {
inboundEndpoint := GetInboundEndpoint(c) closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) return
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { }
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ accountReleaseFunc = fastReleaseFunc
Result: result, }
APIKey: apiKey, currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
User: apiKey.User, if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
Account: account, reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
Subscription: subscription, }
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
// 应用渠道模型映射到 WebSocket 首条消息 token, _, err := h.gatewayService.GetAccessToken(ctx, account)
wsFirstMessage := firstMessage if err != nil {
if channelMappingWS.Mapped { reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return return
} }
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
releaseAccountSlot()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
h.gatewayService.RecordOpenAIAccountSwitch()
reqLog.Warn("openai.websocket_upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if !ensureUserSlotHeld() {
return
}
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
return return
} }
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
} }
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow() _ = conn.CloseNow()
} }
// closeOpenAIWSFailoverExhausted closes the client WebSocket with a close
// status and reason derived from the upstream status code carried by
// failoverErr.
//
// Mapping:
//   - 429                     -> StatusTryAgainLater (rate limit message)
//   - 529, 500, 502, 503, 504 -> StatusTryAgainLater (temporarily unavailable)
//   - 401, 403                -> StatusPolicyViolation (authentication failed)
//   - nil error or any other  -> StatusInternalError (generic proxy failure)
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
	// Start from the generic failure; it covers both a nil error and any
	// status code that has no dedicated mapping below.
	closeStatus := coderws.StatusInternalError
	closeReason := "upstream websocket proxy failed"

	if failoverErr != nil {
		switch failoverErr.StatusCode {
		case http.StatusTooManyRequests:
			closeStatus = coderws.StatusTryAgainLater
			closeReason = "upstream rate limit exceeded, please retry later"
		case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
			closeStatus = coderws.StatusTryAgainLater
			closeReason = "upstream service temporarily unavailable"
		case http.StatusUnauthorized, http.StatusForbidden:
			closeStatus = coderws.StatusPolicyViolation
			closeReason = "upstream websocket authentication failed"
		}
	}

	closeOpenAIClientWS(conn, closeStatus, closeReason)
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil { if conn == nil || decision == nil {
return return

View File

@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
return &account, nil return &account, nil
} }
// openAIWSFailoverHandlerAccountRepoStub is a test double that embeds
// service.AccountRepository and overrides the listing, lookup and
// rate-limit methods with in-memory implementations backed by accounts.
type openAIWSFailoverHandlerAccountRepoStub struct {
service.AccountRepository
// accounts is the in-memory account list served by the List*/GetByID methods
// and updated in place by SetRateLimited.
accounts []service.Account
// rateLimitedIDs records, in call order, every account ID passed to SetRateLimited.
rateLimitedIDs []int64
}
// ListSchedulableByPlatform returns copies of the stub's accounts whose
// Platform equals platform and whose IsSchedulable reports true. The error
// is always nil and the returned slice is never nil.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	matched := make([]service.Account, 0, len(s.accounts))
	for i := range s.accounts {
		candidate := s.accounts[i]
		// Platform is checked first so IsSchedulable only runs on matching accounts.
		if candidate.Platform != platform || !candidate.IsSchedulable() {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}
// ListSchedulableByGroupIDAndPlatform ignores groupID and delegates to
// ListSchedulableByPlatform.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	schedulable, err := s.ListSchedulableByPlatform(ctx, platform)
	return schedulable, err
}
// ListSchedulableUngroupedByPlatform delegates to ListSchedulableByPlatform;
// the stub applies no grouping filter of its own.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	schedulable, err := s.ListSchedulableByPlatform(ctx, platform)
	return schedulable, err
}
// GetByID returns a pointer to a copy of the first stub account whose ID
// equals id. When no account matches it returns (nil, nil), so callers must
// check the pointer as well as the error.
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	for i := range s.accounts {
		if s.accounts[i].ID != id {
			continue
		}
		// Hand out a copy so callers cannot mutate the stub's own slice element.
		found := s.accounts[i]
		return &found, nil
	}
	return nil, nil
}
// SetRateLimited appends id to rateLimitedIDs and, for the first stub account
// with that ID, points RateLimitResetAt at a private copy of resetAt. The ID
// is recorded even when no account matches. It always returns nil.
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	s.rateLimitedIDs = append(s.rateLimitedIDs, id)
	for i := range s.accounts {
		target := &s.accounts[i]
		if target.ID != id {
			continue
		}
		// Copy first so the stored pointer does not alias the parameter.
		deadline := resetAt
		target.RateLimitResetAt = &deadline
		return nil
	}
	return nil
}
type openAIWSUsageHandlerUsageLogRepoStub struct { type openAIWSUsageHandlerUsageLogRepoStub struct {
service.UsageLogRepository service.UsageLogRepository
created chan *service.UsageLog created chan *service.UsageLog
@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
return out, nil return out, nil
} }
// TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent drives the
// ResponsesWebSocket handler against two fake upstream WebSocket servers.
// The first upstream answers the first frame with a usage_limit_reached error
// event; the second answers with a response.completed event. The test asserts
// that the client receives the second upstream's completion, that both
// upstreams received a frame, and that SetRateLimited was called exactly once,
// for account 9902.
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
// Buffered by one so each upstream handler can record its frame without blocking.
firstHitCh := make(chan []byte, 1)
secondHitCh := make(chan []byte, 1)
// First upstream: reads one frame, records it, then replies with a rate-limit error event.
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
firstHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
cancelWrite()
}))
defer firstUpstream.Close()
// Second upstream: reads one frame, records it, replies with response.completed and closes normally.
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
secondHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
cancelWrite()
_ = conn.Close(coderws.StatusNormalClosure, "done")
}))
defer secondUpstream.Close()
groupID := int64(4202)
// Two API-key accounts, each pointed at one fake upstream through base_url.
// NOTE(review): presumably the lower Priority value makes 9902 the first pick — confirm against the scheduler.
accounts := []service.Account{
{
ID: 9902,
Name: "openai-ws-rate-limited",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "sk-first",
"base_url": firstUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
{
ID: 9903,
Name: "openai-ws-healthy",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 2,
Credentials: map[string]any{
"api_key": "sk-second",
"base_url": secondUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
}
// Gateway config: OpenAI WS features switched on, URL allowlist disabled with
// insecure HTTP allowed, 3-second dial/read/write timeouts, up to 3 account switches.
cfg := &config.Config{}
cfg.RunMode = config.RunModeSimple
cfg.Default.RateMultiplier = 1
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
cfg.Gateway.MaxAccountSwitches = 3
// Services are built on the in-memory account repo stub; most other dependencies are passed as nil.
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
service.NewBillingService(cfg, nil),
rateLimitSvc,
billingCacheSvc,
nil,
&service.DeferredService{},
nil,
nil,
nil,
nil,
nil,
nil,
)
// Concurrency cache mock that always grants both user and account slots.
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
h := &OpenAIGatewayHandler{
gatewayService: gatewaySvc,
billingCacheService: billingCacheSvc,
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
maxAccountSwitches: 3,
}
apiKey := &service.APIKey{
ID: 1802,
GroupID: &groupID,
User: &service.User{ID: 1702, Status: service.StatusActive},
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
}
router := gin.New()
// Middleware seeds the gin context with the API key and auth subject before the handler runs.
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
handlerServer := httptest.NewServer(router)
defer handlerServer.Close()
// Dial the handler as a WebSocket client; the http:// test URL is rewritten to ws://.
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(
dialCtx,
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
// Send the single response.create frame from the client.
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
// The first event the client reads must be the second upstream's completion,
// not the first upstream's error event.
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
_, event, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
// Both upstreams must have received a frame.
select {
case <-firstHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第一个上游收到首帧超时")
}
select {
case <-secondHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第二个上游收到重放首帧超时")
}
// SetRateLimited must have been called exactly once, for the first account.
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
}
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
t.Helper() t.Helper()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -2782,6 +2782,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
var dialErr *openAIWSDialError var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(dialErr.ResponseHeaders),
}
} }
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError( return nil, NewOpenAIWSClientCloseError(
@ -2977,6 +2981,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
false, false,
) )
} }
if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
lease.MarkBroken()
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: append([]byte(nil), upstreamMessage...),
ResponseHeaders: cloneHeader(lease.HandshakeHeaders()),
}
}
} }
isTokenEvent := isOpenAIWSTokenEvent(eventType) isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent { if isTokenEvent {

View File

@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
select { select {
case serverErr := <-serverErrCh: case serverErr := <-serverErrCh:
require.Error(t, serverErr) require.Error(t, serverErr)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, serverErr, &failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Len(t, repo.rateLimitCalls, 1) require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):

View File

@ -55,14 +55,18 @@ type RelayExit struct {
} }
type RelayOptions struct { type RelayOptions struct {
WriteTimeout time.Duration WriteTimeout time.Duration
IdleTimeout time.Duration IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType FirstMessageType coderws.MessageType
OnUsageParseFailure func(eventType string, usageRaw string) FirstMessageSent bool
OnTurnComplete func(turn RelayTurnResult) StartClientAfterFirstDownstream bool
OnTrace func(event RelayTraceEvent) OnUsageParseFailure func(eventType string, usageRaw string)
Now func() time.Time OnTurnComplete func(turn RelayTurnResult)
BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error
ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
} }
type RelayTraceEvent struct { type RelayTraceEvent struct {
@ -170,29 +174,47 @@ func Relay(
MessageType: relayMessageTypeString(firstMessageType), MessageType: relayMessageTypeString(firstMessageType),
}) })
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { if options.FirstMessageSent {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed", Stage: "write_first_message_skipped",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
} else {
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream", Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType), MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage), PayloadBytes: len(firstClientMessage),
Error: err.Error(),
}) })
return result, &RelayExit{Stage: "write_upstream", Err: err}
} }
clientToUpstreamFrames.Add(1) clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity() markActivity()
exitCh := make(chan relayExitSignal, 3) exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{} dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) clientReaderStarted := atomic.Bool{}
startClientReader := func() {
if !clientReaderStarted.CompareAndSwap(false, true) {
return
}
go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
}
if !options.StartClientAfterFirstDownstream {
startClientReader()
}
go runUpstreamToClient( go runUpstreamToClient(
relayCtx, relayCtx,
upstreamConn, upstreamConn,
@ -202,6 +224,12 @@ func Relay(
state, state,
options.OnUsageParseFailure, options.OnUsageParseFailure,
options.OnTurnComplete, options.OnTurnComplete,
options.BeforeWriteClient,
func() {
if options.StartClientAfterFirstDownstream {
startClientReader()
}
},
&dropDownstreamWrites, &dropDownstreamWrites,
upstreamToClientFrames, upstreamToClientFrames,
droppedDownstreamFrames, droppedDownstreamFrames,
@ -230,7 +258,9 @@ func Relay(
} else { } else {
relayCancel() relayCancel()
_ = upstreamConn.Close() _ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) if clientReaderStarted.Load() {
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
} }
if hasSecondExit { if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
@ -250,6 +280,14 @@ func Relay(
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load() result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
if firstExit.stage == "read_client" && firstExit.graceful { if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected" stage := "client_disconnected"
exitErr := firstExit.err exitErr := firstExit.err
@ -310,6 +348,14 @@ func Relay(
WroteDownstream: combinedWroteDownstream, WroteDownstream: combinedWroteDownstream,
} }
} }
if options.FirstMessageSent {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete", Stage: "relay_complete",
Graceful: true, Graceful: true,
@ -322,14 +368,20 @@ func Relay(
func runClientToUpstream( func runClientToUpstream(
ctx context.Context, ctx context.Context,
clientConn FrameConn, clientConn FrameConn,
readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error),
writeUpstream func(msgType coderws.MessageType, payload []byte) error, writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(), markActivity func(),
forwardedFrames *atomic.Int64, forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent), onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal, exitCh chan<- relayExitSignal,
) { ) {
if readClientFrame == nil {
readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) {
return conn.ReadFrame(ctx)
}
}
for { for {
msgType, payload, err := clientConn.ReadFrame(ctx) msgType, payload, err := readClientFrame(ctx, clientConn)
if err != nil { if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{ emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed", Stage: "read_client_failed",
@ -368,6 +420,8 @@ func runUpstreamToClient(
state *relayState, state *relayState,
onUsageParseFailure func(eventType string, usageRaw string), onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult), onTurnComplete func(turn RelayTurnResult),
beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error,
afterWriteClient func(),
dropDownstreamWrites *atomic.Bool, dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64, forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64, droppedFrames *atomic.Int64,
@ -395,6 +449,24 @@ func runUpstreamToClient(
return return
} }
markActivity() markActivity()
if beforeWriteClient != nil {
if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "upstream_message_rejected",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{
stage: "upstream_message",
err: err,
wroteDownstream: wroteDownstream,
}
return
}
}
observedEvent := observedUpstreamEvent{} observedEvent := observedUpstreamEvent{}
switch msgType { switch msgType {
case coderws.MessageText: case coderws.MessageText:
@ -438,6 +510,9 @@ func runUpstreamToClient(
return return
} }
wroteDownstream = true wroteDownstream = true
if afterWriteClient != nil {
afterWriteClient()
}
if forwardedFrames != nil { if forwardedFrames != nil {
forwardedFrames.Add(1) forwardedFrames.Add(1)
} }

View File

@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
runClientToUpstream( runClientToUpstream(
context.Background(), context.Background(),
newPassthroughTestFrameConn(nil, true), newPassthroughTestFrameConn(nil, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil }, func(_ coderws.MessageType, _ []byte) error { return nil },
func() {}, func() {},
nil, nil,
@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{ newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true), }, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {}, func() {},
nil, nil,
@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{ newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true), }, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil }, func(_ coderws.MessageType, _ []byte) error { return nil },
func() {}, func() {},
forwarded, forwarded,
@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
nil, nil,
@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
nil, nil,
@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{}, &relayState{},
nil, nil,
nil, nil,
nil,
nil,
drop, drop,
nil, nil,
dropped, dropped,

View File

@ -359,6 +359,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
statusCode, statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
) )
if statusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(handshakeHeaders),
}
}
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
} }
defer func() { defer func() {
@ -456,15 +463,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
cancel() cancel()
}, },
} }
upstreamFirstMessageSent := false
firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage)
cancelFirstWrite()
if firstWriteErr != nil {
return wrapOpenAIWSIngressTurnError(
"write_upstream",
fmt.Errorf("write first upstream websocket request: %w", firstWriteErr),
false,
)
}
upstreamFirstMessageSent = true
// readNextClientFrame is the client-side frame reader handed to the relay.
// It returns only text frames whose JSON "type" field is "response.create".
// Every other client frame is written directly to upstreamFrameConn here and
// the loop keeps reading. A read or write error is returned together with the
// frame that was in flight when it occurred.
readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) {
	for {
		msgType, payload, readErr := conn.ReadFrame(readCtx)
		if readErr != nil {
			return msgType, payload, readErr
		}
		if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
			return msgType, payload, nil
		}
		// Bound the direct upstream write with the configured write timeout, the
		// same way the first-message write is bounded. Using readCtx as-is would
		// leave this write without a per-write deadline, so a stalled upstream
		// could block the client reader.
		writeCtx, cancelWrite := context.WithTimeout(readCtx, s.openAIWSWriteTimeout())
		writeErr := upstreamFrameConn.WriteFrame(writeCtx, msgType, payload)
		cancelWrite()
		if writeErr != nil {
			return msgType, payload, writeErr
		}
	}
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx, Ctx: ctx,
ClientConn: policyClientConn, ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn, UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage, FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{ Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(), WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(), IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText, FirstMessageType: coderws.MessageText,
FirstMessageSent: upstreamFirstMessageSent,
StartClientAfterFirstDownstream: true,
ReadClientFrame: readNextClientFrame,
OnUsageParseFailure: func(eventType string, usageRaw string) { OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough( logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s", "usage_parse_failed event_type=%s usage_raw=%s",
@ -507,6 +545,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
hooks.AfterTurn(turnNo, turnResult, nil) hooks.AfterTurn(turnNo, turnResult, nil)
} }
}, },
// BeforeWriteClient inspects upstream frames until the first one has been
// written to the client. A text "error" event that classifies as a rate limit
// is persisted against the account, logged, and returned as an
// UpstreamFailoverError instead of being allowed through. Every other frame
// passes through unchanged (nil).
BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error {
	// Nothing to check once a frame has been written downstream, or for non-text frames.
	if wroteDownstream || msgType != coderws.MessageText {
		return nil
	}
	eventType, _, _ := parseOpenAIWSEventEnvelope(payload)
	if eventType != "error" {
		return nil
	}
	code, kind, message := parseOpenAIWSErrorEventFields(payload)
	if !isOpenAIWSRateLimitError(code, kind, message) {
		return nil
	}
	s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, code, kind, message)
	logOpenAIWSV2Passthrough(
		"relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s",
		account.ID,
		truncateOpenAIWSLogValue(code, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(kind, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(message, openAIWSLogValueMaxLen),
	)
	// Copy the payload so the returned error does not alias the relay's frame slice.
	bodyCopy := append([]byte(nil), payload...)
	return &UpstreamFailoverError{
		StatusCode:      http.StatusTooManyRequests,
		ResponseBody:    bodyCopy,
		ResponseHeaders: cloneHeader(handshakeHeaders),
	}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) { OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough( logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",