diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 88ece8e7..a51eee86 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { var currentUserRelease func() var currentAccountRelease func() - releaseTurnSlots := func() { + releaseAccountSlot := func() { if currentAccountRelease != nil { currentAccountRelease() currentAccountRelease = nil } + } + releaseTurnSlots := func() { + releaseAccountSlot() if currentUserRelease != nil { currentUserRelease() currentUserRelease = nil @@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { return } currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + ensureUserSlotHeld := func() bool { + if currentUserRelease != nil { + return true + } + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return false + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return false + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + return true + } subscription, _ := middleware2.GetSubscriptionFromContext(c) if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { @@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { firstMessage, openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), ) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( - ctx, - apiKey.GroupID, - previousResponseID, - sessionHash, - reqModel, - nil, - service.OpenAIUpstreamTransportResponsesWebsocketV2, - false, - ) - if err != nil { - reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") - return - } - if selection == nil || selection.Account == nil { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") - return - } + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError - account := selection.Account - accountMaxConcurrency := account.Concurrency - if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { - accountMaxConcurrency = selection.WaitPlan.MaxConcurrency - } - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") - return - } - fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + for { + reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( ctx, - account.ID, - selection.WaitPlan.MaxConcurrency, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) if err != nil { - reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + reqLog.Warn("openai.websocket_account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if lastFailoverErr != nil { + closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr) + } else { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + } return } - if !fastAcquired { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + if selection == nil || selection.Account == nil { + if lastFailoverErr != nil { + closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr) + } else { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + } return } - accountReleaseFunc = fastReleaseFunc - } - currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) - if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - token, _, err := h.gatewayService.GetAccessToken(ctx, account) - if err != nil { - reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") - return - } - - reqLog.Debug("openai.websocket_account_selected", - zap.Int64("account_id", account.ID), - zap.String("account_name", account.Name), - zap.String("schedule_layer", scheduleDecision.Layer), - zap.Int("candidate_count", scheduleDecision.CandidateCount), - ) - - hooks := &service.OpenAIWSIngressHooks{ - InitialRequestModel: reqModel, - BeforeRequest: func(turn int, payload []byte, originalModel string) error { - if turn == 1 { - return nil - } - if !gjson.ValidBytes(payload) { - return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) - } - model := strings.TrimSpace(originalModel) - if model == "" { - model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - } - if model == "" { - model = reqModel - } - if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { - writeContentModerationWSError(ctx, wsConn, decision) - return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) - } - return nil - }, - BeforeTurn: func(turn int) error { - if turn == 1 { - return nil - } - // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 - releaseTurnSlots() - // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 - userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) - if err != nil { - return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) - } - if !userAcquired { - return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) - } - accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) - if err != nil { - if userReleaseFunc != nil { - userReleaseFunc() - } - return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) - } - if !accountAcquired { - if userReleaseFunc != nil { - userReleaseFunc() - } - return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) - } - currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) - currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) - return nil - }, - AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { - releaseTurnSlots() - if turnErr != nil { - if result == nil || result.ImageCount <= 0 { - return - } - reqLog.Warn("openai.websocket_partial_error_with_image_result", - zap.Int64("account_id", account.ID), - zap.Int("image_count", result.ImageCount), - zap.Error(turnErr), - ) - } - if result == nil { + account := selection.Account + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") return } - if account.Type == service.AccountTypeOAuth { - h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) - inboundEndpoint := GetInboundEndpoint(c) - upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { - if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: inboundEndpoint, - UpstreamEndpoint: upstreamEndpoint, - UserAgent: userAgent, - IPAddress: clientIP, - RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), - APIKeyService: h.apiKeyService, - ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), - }); err != nil { - reqLog.Error("openai.websocket_record_usage_failed", - zap.Int64("account_id", account.ID), - zap.String("request_id", result.RequestID), - zap.Error(err), - ) - } - }) - }, - } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } - // 应用渠道模型映射到 WebSocket 首条消息 - wsFirstMessage := firstMessage - if channelMappingWS.Mapped { - wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) - } - - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - closeStatus, closeReason := summarizeWSCloseErrorForLog(err) - reqLog.Warn("openai.websocket_proxy_failed", - zap.Int64("account_id", account.ID), - zap.Error(err), - zap.String("close_status", closeStatus), - zap.String("close_reason", closeReason), - ) - var closeErr *service.OpenAIWSClientCloseError - if errors.As(err, &closeErr) { - closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + token, _, err := h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") return } - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + ) + + hooks := &service.OpenAIWSIngressHooks{ + InitialRequestModel: reqModel, + BeforeRequest: func(turn int, payload []byte, originalModel string) error { + if turn == 1 { + return nil + } + if !gjson.ValidBytes(payload) { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + model := strings.TrimSpace(originalModel) + if model == "" { + model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + } + if model == "" { + model = reqModel + } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) + } + return nil + }, + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil { + if result == nil || result.ImageCount <= 0 { + return + } + reqLog.Warn("openai.websocket_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(turnErr), + ) + } + if result == nil { + return + } + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + }, + } + + // 应用渠道模型映射到 WebSocket 首条消息 + wsFirstMessage := firstMessage + if channelMappingWS.Mapped { + wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + releaseAccountSlot() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + closeOpenAIWSFailoverExhausted(wsConn, failoverErr) + return + } + switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + closeOpenAIWSFailoverExhausted(wsConn, failoverErr) + return + } + h.gatewayService.RecordOpenAIAccountSwitch() + reqLog.Warn("openai.websocket_upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + if !ensureUserSlotHeld() { + return + } + continue + } + + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) return } - reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) + } func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { @@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s _ = conn.CloseNow() } +func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) { + if failoverErr == nil { + closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + switch failoverErr.StatusCode { + case http.StatusTooManyRequests: + closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later") + case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable") + case http.StatusUnauthorized, http.StatusForbidden: + closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed") + default: + closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed") + } +} + func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { if conn == nil || decision == nil { return diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index b304640e..d7d21fac 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in return &account, nil } +type openAIWSFailoverHandlerAccountRepoStub struct { + service.AccountRepository + accounts []service.Account + rateLimitedIDs []int64 +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + out := make([]service.Account, 0, len(s.accounts)) + for _, account := range s.accounts { + if account.Platform == platform && account.IsSchedulable() { + out = append(out, account) + } + } + return out, nil +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) { + for _, account := range s.accounts { + if account.ID == id { + acc := account + return &acc, nil + } + } + return nil, nil +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + s.rateLimitedIDs = append(s.rateLimitedIDs, id) + for i := range s.accounts { + if s.accounts[i].ID == id { + reset := resetAt + s.accounts[i].RateLimitResetAt = &reset + break + } + } + return nil +} + type openAIWSUsageHandlerUsageLogRepoStub struct { service.UsageLogRepository created chan *service.UsageLog @@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont return out, nil } +func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + firstHitCh := make(chan []byte, 1) + secondHitCh := make(chan []byte, 1) + + firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + return + } + defer func() { _ = conn.CloseNow() }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + _, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr == nil { + firstHitCh <- payload + } + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + _ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`)) + cancelWrite() + })) + defer firstUpstream.Close() + + secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + return + } + defer func() { _ = conn.CloseNow() }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + _, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr == nil { + secondHitCh <- payload + } + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + _ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`)) + cancelWrite() + _ = conn.Close(coderws.StatusNormalClosure, "done") + })) + defer secondUpstream.Close() + + groupID := int64(4202) + accounts := []service.Account{ + { + ID: 9902, + Name: "openai-ws-rate-limited", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "sk-first", + "base_url": firstUpstream.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + }, + { + ID: 9903, + Name: "openai-ws-healthy", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 2, + Credentials: map[string]any{ + "api_key": "sk-second", + "base_url": secondUpstream.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + }, + } + + cfg := &config.Config{} + cfg.RunMode = config.RunModeSimple + cfg.Default.RateMultiplier = 1 + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + cfg.Gateway.MaxAccountSwitches = 3 + + accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts} + rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) + gatewaySvc := service.NewOpenAIGatewayService( + accountRepo, + nil, + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + service.NewBillingService(cfg, nil), + rateLimitSvc, + billingCacheSvc, + nil, + &service.DeferredService{}, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + h := &OpenAIGatewayHandler{ + gatewayService: gatewaySvc, + billingCacheService: billingCacheSvc, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + maxAccountSwitches: 3, + } + + apiKey := &service.APIKey{ + ID: 1802, + GroupID: &groupID, + User: &service.User{ID: 1702, Status: service.StatusActive}, + Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1}) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + handlerServer := httptest.NewServer(router) + defer handlerServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses", + &coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover}, + ) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second) + _, event, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String()) + + select { + case <-firstHitCh: + case <-time.After(3 * time.Second): + t.Fatal("等待第一个上游收到首帧超时") + } + select { + case <-secondHitCh: + case <-time.After(3 * time.Second): + t.Fatal("等待第二个上游收到重放首帧超时") + } + require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs) +} + func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { t.Helper() gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 5edf4db9..26a551bd 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2781,6 +2781,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( var dialErr *openAIWSDialError if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseHeaders: cloneHeader(dialErr.ResponseHeaders), + } } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( @@ -2976,6 +2980,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( false, ) } + if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) { + lease.MarkBroken() + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseBody: append([]byte(nil), upstreamMessage...), + ResponseHeaders: cloneHeader(lease.HandshakeHeaders()), + } + } } isTokenEvent := isOpenAIWSTokenEvent(eventType) if isTokenEvent { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 4ee85a3a..a3673d74 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL select { case serverErr := <-serverErrCh: require.Error(t, serverErr) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, serverErr, &failoverErr) + require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode) require.Len(t, repo.rateLimitCalls, 1) require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) case <-time.After(5 * time.Second): diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go index 2b7e2add..35c7569d 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -55,14 +55,18 @@ type RelayExit struct { } type RelayOptions struct { - WriteTimeout time.Duration - IdleTimeout time.Duration - UpstreamDrainTimeout time.Duration - FirstMessageType coderws.MessageType - OnUsageParseFailure func(eventType string, usageRaw string) - OnTurnComplete func(turn RelayTurnResult) - OnTrace func(event RelayTraceEvent) - Now func() time.Time + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + FirstMessageSent bool + StartClientAfterFirstDownstream bool + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error + ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error) + OnTrace func(event RelayTraceEvent) + Now func() time.Time } type RelayTraceEvent struct { @@ -170,29 +174,47 @@ func Relay( MessageType: relayMessageTypeString(firstMessageType), }) - if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { - result.Duration = nowFn().Sub(startAt) + if options.FirstMessageSent { emitRelayTrace(onTrace, RelayTraceEvent{ - Stage: "write_first_message_failed", + Stage: "write_first_message_skipped", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + }) + } else { + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", Direction: "client_to_upstream", MessageType: relayMessageTypeString(firstMessageType), PayloadBytes: len(firstClientMessage), - Error: err.Error(), }) - return result, &RelayExit{Stage: "write_upstream", Err: err} } clientToUpstreamFrames.Add(1) - emitRelayTrace(onTrace, RelayTraceEvent{ - Stage: "write_first_message_ok", - Direction: "client_to_upstream", - MessageType: relayMessageTypeString(firstMessageType), - PayloadBytes: len(firstClientMessage), - }) markActivity() exitCh := make(chan relayExitSignal, 3) dropDownstreamWrites := atomic.Bool{} - go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + clientReaderStarted := atomic.Bool{} + startClientReader := func() { + if !clientReaderStarted.CompareAndSwap(false, true) { + return + } + go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + } + if !options.StartClientAfterFirstDownstream { + startClientReader() + } go runUpstreamToClient( relayCtx, upstreamConn, @@ -202,6 +224,12 @@ func Relay( state, options.OnUsageParseFailure, options.OnTurnComplete, + options.BeforeWriteClient, + func() { + if options.StartClientAfterFirstDownstream { + startClientReader() + } + }, &dropDownstreamWrites, upstreamToClientFrames, droppedDownstreamFrames, @@ -230,7 +258,9 @@ func Relay( } else { relayCancel() _ = upstreamConn.Close() - secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + if clientReaderStarted.Load() { + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } } if hasSecondExit { combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream @@ -250,6 +280,14 @@ func Relay( result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() result.UpstreamToClientFrames = upstreamToClientFrames.Load() result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_client_closed", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + return result, nil + } if firstExit.stage == "read_client" && firstExit.graceful { stage := "client_disconnected" exitErr := firstExit.err @@ -310,6 +348,14 @@ func Relay( WroteDownstream: combinedWroteDownstream, } } + if options.FirstMessageSent { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_client_closed", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + return result, nil + } emitRelayTrace(onTrace, RelayTraceEvent{ Stage: "relay_complete", Graceful: true, @@ -322,14 +368,20 @@ func Relay( func runClientToUpstream( ctx context.Context, clientConn FrameConn, + readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error), writeUpstream func(msgType coderws.MessageType, payload []byte) error, markActivity func(), forwardedFrames *atomic.Int64, onTrace func(event RelayTraceEvent), exitCh chan<- relayExitSignal, ) { + if readClientFrame == nil { + readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) { + return conn.ReadFrame(ctx) + } + } for { - msgType, payload, err := clientConn.ReadFrame(ctx) + msgType, payload, err := readClientFrame(ctx, clientConn) if err != nil { emitRelayTrace(onTrace, RelayTraceEvent{ Stage: "read_client_failed", @@ -368,6 +420,8 @@ func runUpstreamToClient( state *relayState, onUsageParseFailure func(eventType string, usageRaw string), onTurnComplete func(turn RelayTurnResult), + beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error, + afterWriteClient func(), dropDownstreamWrites *atomic.Bool, forwardedFrames *atomic.Int64, droppedFrames *atomic.Int64, @@ -395,6 +449,24 @@ func runUpstreamToClient( return } markActivity() + if beforeWriteClient != nil { + if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "upstream_message_rejected", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{ + stage: "upstream_message", + err: err, + wroteDownstream: wroteDownstream, + } + return + } + } observedEvent := observedUpstreamEvent{} switch msgType { case coderws.MessageText: @@ -438,6 +510,9 @@ func runUpstreamToClient( return } wroteDownstream = true + if afterWriteClient != nil { + afterWriteClient() + } if forwardedFrames != nil { forwardedFrames.Add(1) } diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go index 123e10ce..52104482 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { runClientToUpstream( context.Background(), newPassthroughTestFrameConn(nil, true), + nil, func(_ coderws.MessageType, _ []byte) error { return nil }, func() {}, nil, @@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { newPassthroughTestFrameConn([]passthroughTestFrame{ {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, }, true), + nil, func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, func() {}, nil, @@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { newPassthroughTestFrameConn([]passthroughTestFrame{ {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, }, true), + nil, func(_ coderws.MessageType, _ []byte) error { return nil }, func() {}, forwarded, @@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, nil, @@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, nil, @@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, dropped, diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 0a89e2dd..406e2849 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -358,6 +358,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( statusCode, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), ) + if statusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + return &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseHeaders: cloneHeader(handshakeHeaders), + } + } return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) } defer func() { @@ -454,15 +461,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( cancel() }, } + upstreamFirstMessageSent := false + firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage) + cancelFirstWrite() + if firstWriteErr != nil { + return wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write first upstream websocket request: %w", firstWriteErr), + false, + ) + } + upstreamFirstMessageSent = true + + readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) { + for { + msgType, payload, readErr := conn.ReadFrame(readCtx) + if readErr != nil { + return msgType, payload, readErr + } + if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + return msgType, payload, nil + } + if writeErr := upstreamFrameConn.WriteFrame(readCtx, msgType, payload); writeErr != nil { + return msgType, payload, writeErr + } + } + } + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ Ctx: ctx, ClientConn: policyClientConn, UpstreamConn: upstreamFrameConn, FirstClientMessage: firstClientMessage, Options: openaiwsv2.RelayOptions{ - WriteTimeout: s.openAIWSWriteTimeout(), - IdleTimeout: s.openAIWSPassthroughIdleTimeout(), - FirstMessageType: coderws.MessageText, + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: coderws.MessageText, + FirstMessageSent: upstreamFirstMessageSent, + StartClientAfterFirstDownstream: true, + ReadClientFrame: readNextClientFrame, OnUsageParseFailure: func(eventType string, usageRaw string) { logOpenAIWSV2Passthrough( "usage_parse_failed event_type=%s usage_raw=%s", @@ -505,6 +543,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( hooks.AfterTurn(turnNo, turnResult, nil) } }, + BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error { + if msgType != coderws.MessageText || wroteDownstream { + return nil + } + if eventType, _, _ := parseOpenAIWSEventEnvelope(payload); eventType != "error" { + return nil + } + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(payload) + if !isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) { + return nil + } + s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSV2Passthrough( + "relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s", + account.ID, + truncateOpenAIWSLogValue(errCodeRaw, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(errMsgRaw, openAIWSLogValueMaxLen), + ) + return &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseBody: append([]byte(nil), payload...), + ResponseHeaders: cloneHeader(handshakeHeaders), + } + }, OnTrace: func(event openaiwsv2.RelayTraceEvent) { logOpenAIWSV2Passthrough( "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",