fix: enable account failover for OpenAI WS rate limits

This commit is contained in:
siyuan 2026-05-26 20:07:00 +08:00
parent 9ef144874a
commit 08061717b8
7 changed files with 687 additions and 198 deletions

View File

@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
// releaseAccountSlot gives back the account concurrency slot currently held
// for this connection, if any, and forgets the release callback so that a
// second call is a no-op.
releaseAccountSlot := func() {
	if currentAccountRelease == nil {
		return
	}
	currentAccountRelease()
	currentAccountRelease = nil
}
releaseTurnSlots := func() {
releaseAccountSlot()
if currentUserRelease != nil {
currentUserRelease()
currentUserRelease = nil
@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
return
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
// ensureUserSlotHeld makes sure a user concurrency slot is held before the
// handler continues. When a slot is already tracked it returns true right
// away. Otherwise it tries to acquire one; if that fails or no slot is
// available, it closes the client WebSocket and returns false, so the caller
// must stop.
ensureUserSlotHeld := func() bool {
	if currentUserRelease != nil {
		return true
	}
	release, acquired, acquireErr := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
	switch {
	case acquireErr != nil:
		reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(acquireErr))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
		return false
	case !acquired:
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
		return false
	}
	currentUserRelease = wrapReleaseOnDone(ctx, release)
	return true
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
)
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
for {
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
reqLog.Warn("openai.websocket_account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
if selection == nil || selection.Account == nil {
if lastFailoverErr != nil {
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
} else {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
}
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
InitialRequestModel: reqModel,
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
if turn == 1 {
return nil
}
if !gjson.ValidBytes(payload) {
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
model := strings.TrimSpace(originalModel)
if model == "" {
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
}
if model == "" {
model = reqModel
}
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
writeContentModerationWSError(ctx, wsConn, decision)
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
}
return nil
},
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil {
if result == nil || result.ImageCount <= 0 {
return
}
reqLog.Warn("openai.websocket_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
zap.Int("image_count", result.ImageCount),
zap.Error(turnErr),
)
}
if result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
// 应用渠道模型映射到 WebSocket 首条消息
wsFirstMessage := firstMessage
if channelMappingWS.Mapped {
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
releaseAccountSlot()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
return
}
h.gatewayService.RecordOpenAIAccountSwitch()
reqLog.Warn("openai.websocket_upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
if !ensureUserSlotHeld() {
return
}
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
_ = conn.CloseNow()
}
// closeOpenAIWSFailoverExhausted closes the client WebSocket once account
// failover cannot continue. The close status and reason are derived from the
// HTTP status code carried by failoverErr; a nil failoverErr or an
// unrecognised status code falls back to an internal-error close.
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
	// Fallback used for a nil error and for any status not listed below.
	status := coderws.StatusInternalError
	reason := "upstream websocket proxy failed"

	if failoverErr != nil {
		switch failoverErr.StatusCode {
		case http.StatusTooManyRequests:
			status = coderws.StatusTryAgainLater
			reason = "upstream rate limit exceeded, please retry later"
		// 529 has no named constant in net/http, hence the literal.
		case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
			status = coderws.StatusTryAgainLater
			reason = "upstream service temporarily unavailable"
		case http.StatusUnauthorized, http.StatusForbidden:
			status = coderws.StatusPolicyViolation
			reason = "upstream websocket authentication failed"
		}
	}

	closeOpenAIClientWS(conn, status, reason)
}
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
if conn == nil || decision == nil {
return

View File

@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
return &account, nil
}
// openAIWSFailoverHandlerAccountRepoStub is an in-memory account repository
// used by the WebSocket failover handler test. It serves accounts from a
// slice and records which account IDs were marked as rate limited.
type openAIWSFailoverHandlerAccountRepoStub struct {
// Embedded so that repository methods not overridden on this stub are
// promoted from service.AccountRepository.
service.AccountRepository
// accounts is the backing store; SetRateLimited mutates entries in place.
accounts []service.Account
// rateLimitedIDs lists, in call order, every id passed to SetRateLimited.
rateLimitedIDs []int64
}
// ListSchedulableByPlatform returns copies of the stub's accounts whose
// Platform equals platform and for which IsSchedulable reports true. The
// error result is always nil.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	matched := make([]service.Account, 0, len(s.accounts))
	for i := range s.accounts {
		// Work on a copy so callers never alias the stub's backing slice.
		candidate := s.accounts[i]
		if candidate.Platform != platform || !candidate.IsSchedulable() {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}
// ListSchedulableByGroupIDAndPlatform ignores groupID and returns the same
// result as ListSchedulableByPlatform for the given platform.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
	// The stub does not model group membership, so only platform is used.
	matched, err := s.ListSchedulableByPlatform(ctx, platform)
	return matched, err
}
// ListSchedulableUngroupedByPlatform returns the same result as
// ListSchedulableByPlatform for the given platform.
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
	matched, err := s.ListSchedulableByPlatform(ctx, platform)
	return matched, err
}
// GetByID returns a pointer to a copy of the first stub account whose ID
// equals id. The error result is always nil.
//
// NOTE(review): when no account matches, this returns (nil, nil); callers
// must nil-check the account even on a nil error.
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
	for i := range s.accounts {
		if s.accounts[i].ID != id {
			continue
		}
		// Copy before taking the address so the caller cannot mutate the stub.
		found := s.accounts[i]
		return &found, nil
	}
	return nil, nil
}
// SetRateLimited appends id to rateLimitedIDs and, if the stub holds an
// account with that ID, points its RateLimitResetAt at a copy of resetAt.
// Only the first matching account is updated. It always returns nil.
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	// Recorded even when no account matches, so tests can assert on every call.
	s.rateLimitedIDs = append(s.rateLimitedIDs, id)
	for i := range s.accounts {
		target := &s.accounts[i]
		if target.ID != id {
			continue
		}
		resetCopy := resetAt
		target.RateLimitResetAt = &resetCopy
		return nil
	}
	return nil
}
type openAIWSUsageHandlerUsageLogRepoStub struct {
service.UsageLogRepository
created chan *service.UsageLog
@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
return out, nil
}
// TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent runs the
// ResponsesWebSocket handler against two fake upstream WebSocket servers.
// The first upstream answers the first frame with a rate-limit error event;
// the second answers with a response.completed event. The test asserts that
// the client receives the completed response, that both upstreams received a
// frame, and that only the first account was reported as rate limited.
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
// Buffered by one so each upstream handler can report its received frame
// without waiting for the test goroutine to read it.
firstHitCh := make(chan []byte, 1)
secondHitCh := make(chan []byte, 1)
// First upstream: reads one frame, then replies with an error event whose
// code is rate_limit_exceeded and type is usage_limit_reached.
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
firstHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
cancelWrite()
}))
defer firstUpstream.Close()
// Second upstream: reads one frame, replies with a response.completed
// event, then closes the connection normally.
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
if err != nil {
return
}
defer func() { _ = conn.CloseNow() }()
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
_, payload, readErr := conn.Read(readCtx)
cancelRead()
if readErr == nil {
secondHitCh <- payload
}
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
cancelWrite()
_ = conn.Close(coderws.StatusNormalClosure, "done")
}))
defer secondUpstream.Close()
groupID := int64(4202)
// Two API-key accounts, one per fake upstream (via base_url).
// NOTE(review): the test relies on account 9902 being tried first;
// presumably Priority 1 is scheduled ahead of Priority 2 - confirm against
// the scheduler.
accounts := []service.Account{
{
ID: 9902,
Name: "openai-ws-rate-limited",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "sk-first",
"base_url": firstUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
{
ID: 9903,
Name: "openai-ws-healthy",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 2,
Credentials: map[string]any{
"api_key": "sk-second",
"base_url": secondUpstream.URL,
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
},
},
}
// Gateway configuration: the URL allowlist is disabled and insecure HTTP is
// allowed so the httptest upstream URLs can be used; OpenAI WS v2 is enabled
// with 3-second dial/read/write timeouts.
cfg := &config.Config{}
cfg.RunMode = config.RunModeSimple
cfg.Default.RateMultiplier = 1
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
cfg.Gateway.MaxAccountSwitches = 3
// The same stub backs both the rate-limit service and the gateway service,
// so SetRateLimited calls are visible through accountRepo.rateLimitedIDs.
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
nil,
nil,
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
service.NewBillingService(cfg, nil),
rateLimitSvc,
billingCacheSvc,
nil,
&service.DeferredService{},
nil,
nil,
nil,
nil,
nil,
nil,
)
// Concurrency cache mock that always grants user and account slots.
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
h := &OpenAIGatewayHandler{
gatewayService: gatewaySvc,
billingCacheService: billingCacheSvc,
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
maxAccountSwitches: 3,
}
apiKey := &service.APIKey{
ID: 1802,
GroupID: &groupID,
User: &service.User{ID: 1702, Status: service.StatusActive},
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
}
// Middleware stands in for authentication by placing the API key and the
// auth subject into the gin context before the handler runs.
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
handlerServer := httptest.NewServer(router)
defer handlerServer.Close()
// Dial the handler as a WebSocket client; the http:// test server URL is
// rewritten to ws:// by swapping the scheme prefix.
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(
dialCtx,
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
// Send a single response.create frame as the first client message.
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
cancelWrite()
require.NoError(t, err)
// The first event the client reads must be the second upstream's completed
// response, not the first upstream's error event.
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
_, event, err := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, err)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
// Both upstreams must have received a frame from the gateway.
select {
case <-firstHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第一个上游收到首帧超时")
}
select {
case <-secondHitCh:
case <-time.After(3 * time.Second):
t.Fatal("等待第二个上游收到重放首帧超时")
}
// Exactly one SetRateLimited call is expected, for the first account.
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
}
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
t.Helper()
gin.SetMode(gin.TestMode)

View File

@ -2781,6 +2781,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(dialErr.ResponseHeaders),
}
}
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError(
@ -2976,6 +2980,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
false,
)
}
if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
lease.MarkBroken()
return nil, &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseBody: append([]byte(nil), upstreamMessage...),
ResponseHeaders: cloneHeader(lease.HandshakeHeaders()),
}
}
}
isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent {

View File

@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
var failoverErr *UpstreamFailoverError
require.ErrorAs(t, serverErr, &failoverErr)
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
require.Len(t, repo.rateLimitCalls, 1)
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
case <-time.After(5 * time.Second):

View File

@ -55,14 +55,18 @@ type RelayExit struct {
}
type RelayOptions struct {
WriteTimeout time.Duration
IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType
OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
WriteTimeout time.Duration
IdleTimeout time.Duration
UpstreamDrainTimeout time.Duration
FirstMessageType coderws.MessageType
FirstMessageSent bool
StartClientAfterFirstDownstream bool
OnUsageParseFailure func(eventType string, usageRaw string)
OnTurnComplete func(turn RelayTurnResult)
BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error
ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error)
OnTrace func(event RelayTraceEvent)
Now func() time.Time
}
type RelayTraceEvent struct {
@ -170,29 +174,47 @@ func Relay(
MessageType: relayMessageTypeString(firstMessageType),
})
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
if options.FirstMessageSent {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Stage: "write_first_message_skipped",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
} else {
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
result.Duration = nowFn().Sub(startAt)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_failed",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
Error: err.Error(),
})
return result, &RelayExit{Stage: "write_upstream", Err: err}
}
clientToUpstreamFrames.Add(1)
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "write_first_message_ok",
Direction: "client_to_upstream",
MessageType: relayMessageTypeString(firstMessageType),
PayloadBytes: len(firstClientMessage),
})
markActivity()
exitCh := make(chan relayExitSignal, 3)
dropDownstreamWrites := atomic.Bool{}
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
clientReaderStarted := atomic.Bool{}
startClientReader := func() {
if !clientReaderStarted.CompareAndSwap(false, true) {
return
}
go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
}
if !options.StartClientAfterFirstDownstream {
startClientReader()
}
go runUpstreamToClient(
relayCtx,
upstreamConn,
@ -202,6 +224,12 @@ func Relay(
state,
options.OnUsageParseFailure,
options.OnTurnComplete,
options.BeforeWriteClient,
func() {
if options.StartClientAfterFirstDownstream {
startClientReader()
}
},
&dropDownstreamWrites,
upstreamToClientFrames,
droppedDownstreamFrames,
@ -230,7 +258,9 @@ func Relay(
} else {
relayCancel()
_ = upstreamConn.Close()
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
if clientReaderStarted.Load() {
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
}
}
if hasSecondExit {
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
@ -250,6 +280,14 @@ func Relay(
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
if firstExit.stage == "read_client" && firstExit.graceful {
stage := "client_disconnected"
exitErr := firstExit.err
@ -310,6 +348,14 @@ func Relay(
WroteDownstream: combinedWroteDownstream,
}
}
if options.FirstMessageSent {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_client_closed",
Graceful: true,
WroteDownstream: combinedWroteDownstream,
})
return result, nil
}
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "relay_complete",
Graceful: true,
@ -322,14 +368,20 @@ func Relay(
func runClientToUpstream(
ctx context.Context,
clientConn FrameConn,
readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error),
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
markActivity func(),
forwardedFrames *atomic.Int64,
onTrace func(event RelayTraceEvent),
exitCh chan<- relayExitSignal,
) {
if readClientFrame == nil {
readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) {
return conn.ReadFrame(ctx)
}
}
for {
msgType, payload, err := clientConn.ReadFrame(ctx)
msgType, payload, err := readClientFrame(ctx, clientConn)
if err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "read_client_failed",
@ -368,6 +420,8 @@ func runUpstreamToClient(
state *relayState,
onUsageParseFailure func(eventType string, usageRaw string),
onTurnComplete func(turn RelayTurnResult),
beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error,
afterWriteClient func(),
dropDownstreamWrites *atomic.Bool,
forwardedFrames *atomic.Int64,
droppedFrames *atomic.Int64,
@ -395,6 +449,24 @@ func runUpstreamToClient(
return
}
markActivity()
if beforeWriteClient != nil {
if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil {
emitRelayTrace(onTrace, RelayTraceEvent{
Stage: "upstream_message_rejected",
Direction: "upstream_to_client",
MessageType: relayMessageTypeString(msgType),
PayloadBytes: len(payload),
WroteDownstream: wroteDownstream,
Error: err.Error(),
})
exitCh <- relayExitSignal{
stage: "upstream_message",
err: err,
wroteDownstream: wroteDownstream,
}
return
}
}
observedEvent := observedUpstreamEvent{}
switch msgType {
case coderws.MessageText:
@ -438,6 +510,9 @@ func runUpstreamToClient(
return
}
wroteDownstream = true
if afterWriteClient != nil {
afterWriteClient()
}
if forwardedFrames != nil {
forwardedFrames.Add(1)
}

View File

@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
runClientToUpstream(
context.Background(),
newPassthroughTestFrameConn(nil, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
nil,
@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
func() {},
nil,
@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
newPassthroughTestFrameConn([]passthroughTestFrame{
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
}, true),
nil,
func(_ coderws.MessageType, _ []byte) error { return nil },
func() {},
forwarded,
@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{},
nil,
nil,
nil,
nil,
drop,
nil,
nil,
@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{},
nil,
nil,
nil,
nil,
drop,
nil,
nil,
@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
&relayState{},
nil,
nil,
nil,
nil,
drop,
nil,
dropped,

View File

@ -358,6 +358,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
if statusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
return &UpstreamFailoverError{
StatusCode: http.StatusTooManyRequests,
ResponseHeaders: cloneHeader(handshakeHeaders),
}
}
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
defer func() {
@ -454,15 +461,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
cancel()
},
}
upstreamFirstMessageSent := false
firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage)
cancelFirstWrite()
if firstWriteErr != nil {
return wrapOpenAIWSIngressTurnError(
"write_upstream",
fmt.Errorf("write first upstream websocket request: %w", firstWriteErr),
false,
)
}
upstreamFirstMessageSent = true
// readNextClientFrame reads client frames until it sees a text frame whose
// top-level JSON "type" field (after trimming whitespace) is
// "response.create", and returns that frame to the caller. Every other frame
// is written directly to upstreamFrameConn here and the loop keeps reading.
//
// Each forwarded write is bounded by s.openAIWSWriteTimeout(), the same bound
// applied to the first upstream write, so a stalled upstream cannot block the
// client reader indefinitely.
//
// A read error is returned as-is together with whatever ReadFrame produced.
// NOTE(review): a failed forward write is also surfaced through this
// function's error return, so the caller cannot distinguish it from a client
// read failure — confirm that is the intended classification.
readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) {
	for {
		msgType, payload, readErr := conn.ReadFrame(readCtx)
		if readErr != nil {
			return msgType, payload, readErr
		}
		if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
			return msgType, payload, nil
		}
		// Derive a per-write deadline from the read context so cancellation
		// still propagates while the write itself cannot hang forever.
		writeCtx, cancelWrite := context.WithTimeout(readCtx, s.openAIWSWriteTimeout())
		writeErr := upstreamFrameConn.WriteFrame(writeCtx, msgType, payload)
		cancelWrite()
		if writeErr != nil {
			return msgType, payload, writeErr
		}
	}
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
FirstMessageSent: upstreamFirstMessageSent,
StartClientAfterFirstDownstream: true,
ReadClientFrame: readNextClientFrame,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
@ -505,6 +543,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
// BeforeWriteClient inspects an upstream frame before it is relayed to the
// client. It returns an *UpstreamFailoverError (HTTP 429) only when the frame
// is a text "error" event classified as a rate limit and nothing has been
// written downstream yet; in every other case it returns nil.
BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error {
	// Only a text frame seen before any downstream write is examined.
	if wroteDownstream || msgType != coderws.MessageText {
		return nil
	}
	eventType, _, _ := parseOpenAIWSEventEnvelope(payload)
	if eventType != "error" {
		return nil
	}
	code, kind, message := parseOpenAIWSErrorEventFields(payload)
	if !isOpenAIWSRateLimitError(code, kind, message) {
		return nil
	}
	// Record the rate-limit signal for this account, then log the failover.
	s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, code, kind, message)
	logOpenAIWSV2Passthrough(
		"relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s",
		account.ID,
		truncateOpenAIWSLogValue(code, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(kind, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(message, openAIWSLogValueMaxLen),
	)
	// The payload is copied so the returned error does not alias the
	// caller's slice.
	return &UpstreamFailoverError{
		StatusCode:      http.StatusTooManyRequests,
		ResponseBody:    append([]byte(nil), payload...),
		ResponseHeaders: cloneHeader(handshakeHeaders),
	}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",