fix: enable account failover for OpenAI WS rate limits
This commit is contained in:
parent
9ef144874a
commit
08061717b8
@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
releaseAccountSlot := func() {
|
||||
if currentAccountRelease != nil {
|
||||
currentAccountRelease()
|
||||
currentAccountRelease = nil
|
||||
}
|
||||
}
|
||||
releaseTurnSlots := func() {
|
||||
releaseAccountSlot()
|
||||
if currentUserRelease != nil {
|
||||
currentUserRelease()
|
||||
currentUserRelease = nil
|
||||
@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
ensureUserSlotHeld := func() bool {
|
||||
if currentUserRelease != nil {
|
||||
return true
|
||||
}
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||
return false
|
||||
}
|
||||
if !userAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||
return false
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
return true
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
firstMessage,
|
||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||
)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
for {
|
||||
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
reqLog.Warn("openai.websocket_account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if lastFailoverErr != nil {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||
} else {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
}
|
||||
return
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
if selection == nil || selection.Account == nil {
|
||||
if lastFailoverErr != nil {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||
} else {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
}
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
releaseAccountSlot()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||
return
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
reqLog.Warn("openai.websocket_upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if !ensureUserSlotHeld() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||
@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
|
||||
if failoverErr == nil {
|
||||
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
switch failoverErr.StatusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later")
|
||||
case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable")
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed")
|
||||
default:
|
||||
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
}
|
||||
}
|
||||
|
||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||
if conn == nil || decision == nil {
|
||||
return
|
||||
|
||||
@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSFailoverHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
accounts []service.Account
|
||||
rateLimitedIDs []int64
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
out := make([]service.Account, 0, len(s.accounts))
|
||||
for _, account := range s.accounts {
|
||||
if account.Platform == platform && account.IsSchedulable() {
|
||||
out = append(out, account)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
for _, account := range s.accounts {
|
||||
if account.ID == id {
|
||||
acc := account
|
||||
return &acc, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateLimitedIDs = append(s.rateLimitedIDs, id)
|
||||
for i := range s.accounts {
|
||||
if s.accounts[i].ID == id {
|
||||
reset := resetAt
|
||||
s.accounts[i].RateLimitResetAt = &reset
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
firstHitCh := make(chan []byte, 1)
|
||||
secondHitCh := make(chan []byte, 1)
|
||||
|
||||
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
firstHitCh <- payload
|
||||
}
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
|
||||
cancelWrite()
|
||||
}))
|
||||
defer firstUpstream.Close()
|
||||
|
||||
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
secondHitCh <- payload
|
||||
}
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
|
||||
cancelWrite()
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
}))
|
||||
defer secondUpstream.Close()
|
||||
|
||||
groupID := int64(4202)
|
||||
accounts := []service.Account{
|
||||
{
|
||||
ID: 9902,
|
||||
Name: "openai-ws-rate-limited",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-first",
|
||||
"base_url": firstUpstream.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 9903,
|
||||
Name: "openai-ws-healthy",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 2,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-second",
|
||||
"base_url": secondUpstream.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
cfg.Gateway.MaxAccountSwitches = 3
|
||||
|
||||
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
|
||||
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
rateLimitSvc,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
maxAccountSwitches: 3,
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1802,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1702, Status: service.StatusActive},
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = clientConn.CloseNow() }()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
|
||||
|
||||
select {
|
||||
case <-firstHitCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待第一个上游收到首帧超时")
|
||||
}
|
||||
select {
|
||||
case <-secondHitCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待第二个上游收到重放首帧超时")
|
||||
}
|
||||
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -2781,6 +2781,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
var dialErr *openAIWSDialError
|
||||
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseHeaders: cloneHeader(dialErr.ResponseHeaders),
|
||||
}
|
||||
}
|
||||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||||
return nil, NewOpenAIWSClientCloseError(
|
||||
@ -2976,6 +2980,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
false,
|
||||
)
|
||||
}
|
||||
if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
|
||||
lease.MarkBroken()
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseBody: append([]byte(nil), upstreamMessage...),
|
||||
ResponseHeaders: cloneHeader(lease.HandshakeHeaders()),
|
||||
}
|
||||
}
|
||||
}
|
||||
isTokenEvent := isOpenAIWSTokenEvent(eventType)
|
||||
if isTokenEvent {
|
||||
|
||||
@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
var failoverErr *UpstreamFailoverError
|
||||
require.ErrorAs(t, serverErr, &failoverErr)
|
||||
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||
require.Len(t, repo.rateLimitCalls, 1)
|
||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||
case <-time.After(5 * time.Second):
|
||||
|
||||
@ -55,14 +55,18 @@ type RelayExit struct {
|
||||
}
|
||||
|
||||
type RelayOptions struct {
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
UpstreamDrainTimeout time.Duration
|
||||
FirstMessageType coderws.MessageType
|
||||
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||
OnTurnComplete func(turn RelayTurnResult)
|
||||
OnTrace func(event RelayTraceEvent)
|
||||
Now func() time.Time
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
UpstreamDrainTimeout time.Duration
|
||||
FirstMessageType coderws.MessageType
|
||||
FirstMessageSent bool
|
||||
StartClientAfterFirstDownstream bool
|
||||
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||
OnTurnComplete func(turn RelayTurnResult)
|
||||
BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error
|
||||
ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error)
|
||||
OnTrace func(event RelayTraceEvent)
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
type RelayTraceEvent struct {
|
||||
@ -170,29 +174,47 @@ func Relay(
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
})
|
||||
|
||||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||
result.Duration = nowFn().Sub(startAt)
|
||||
if options.FirstMessageSent {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_failed",
|
||||
Stage: "write_first_message_skipped",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
})
|
||||
} else {
|
||||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||
result.Duration = nowFn().Sub(startAt)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_failed",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
Error: err.Error(),
|
||||
})
|
||||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_ok",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
Error: err.Error(),
|
||||
})
|
||||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||
}
|
||||
clientToUpstreamFrames.Add(1)
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "write_first_message_ok",
|
||||
Direction: "client_to_upstream",
|
||||
MessageType: relayMessageTypeString(firstMessageType),
|
||||
PayloadBytes: len(firstClientMessage),
|
||||
})
|
||||
markActivity()
|
||||
|
||||
exitCh := make(chan relayExitSignal, 3)
|
||||
dropDownstreamWrites := atomic.Bool{}
|
||||
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||
clientReaderStarted := atomic.Bool{}
|
||||
startClientReader := func() {
|
||||
if !clientReaderStarted.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||
}
|
||||
if !options.StartClientAfterFirstDownstream {
|
||||
startClientReader()
|
||||
}
|
||||
go runUpstreamToClient(
|
||||
relayCtx,
|
||||
upstreamConn,
|
||||
@ -202,6 +224,12 @@ func Relay(
|
||||
state,
|
||||
options.OnUsageParseFailure,
|
||||
options.OnTurnComplete,
|
||||
options.BeforeWriteClient,
|
||||
func() {
|
||||
if options.StartClientAfterFirstDownstream {
|
||||
startClientReader()
|
||||
}
|
||||
},
|
||||
&dropDownstreamWrites,
|
||||
upstreamToClientFrames,
|
||||
droppedDownstreamFrames,
|
||||
@ -230,7 +258,9 @@ func Relay(
|
||||
} else {
|
||||
relayCancel()
|
||||
_ = upstreamConn.Close()
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||
if clientReaderStarted.Load() {
|
||||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||
}
|
||||
}
|
||||
if hasSecondExit {
|
||||
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||
@ -250,6 +280,14 @@ func Relay(
|
||||
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||
if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_client_closed",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||
stage := "client_disconnected"
|
||||
exitErr := firstExit.err
|
||||
@ -310,6 +348,14 @@ func Relay(
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
}
|
||||
}
|
||||
if options.FirstMessageSent {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_client_closed",
|
||||
Graceful: true,
|
||||
WroteDownstream: combinedWroteDownstream,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "relay_complete",
|
||||
Graceful: true,
|
||||
@ -322,14 +368,20 @@ func Relay(
|
||||
func runClientToUpstream(
|
||||
ctx context.Context,
|
||||
clientConn FrameConn,
|
||||
readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error),
|
||||
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||
markActivity func(),
|
||||
forwardedFrames *atomic.Int64,
|
||||
onTrace func(event RelayTraceEvent),
|
||||
exitCh chan<- relayExitSignal,
|
||||
) {
|
||||
if readClientFrame == nil {
|
||||
readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) {
|
||||
return conn.ReadFrame(ctx)
|
||||
}
|
||||
}
|
||||
for {
|
||||
msgType, payload, err := clientConn.ReadFrame(ctx)
|
||||
msgType, payload, err := readClientFrame(ctx, clientConn)
|
||||
if err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "read_client_failed",
|
||||
@ -368,6 +420,8 @@ func runUpstreamToClient(
|
||||
state *relayState,
|
||||
onUsageParseFailure func(eventType string, usageRaw string),
|
||||
onTurnComplete func(turn RelayTurnResult),
|
||||
beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error,
|
||||
afterWriteClient func(),
|
||||
dropDownstreamWrites *atomic.Bool,
|
||||
forwardedFrames *atomic.Int64,
|
||||
droppedFrames *atomic.Int64,
|
||||
@ -395,6 +449,24 @@ func runUpstreamToClient(
|
||||
return
|
||||
}
|
||||
markActivity()
|
||||
if beforeWriteClient != nil {
|
||||
if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil {
|
||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||
Stage: "upstream_message_rejected",
|
||||
Direction: "upstream_to_client",
|
||||
MessageType: relayMessageTypeString(msgType),
|
||||
PayloadBytes: len(payload),
|
||||
WroteDownstream: wroteDownstream,
|
||||
Error: err.Error(),
|
||||
})
|
||||
exitCh <- relayExitSignal{
|
||||
stage: "upstream_message",
|
||||
err: err,
|
||||
wroteDownstream: wroteDownstream,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
observedEvent := observedUpstreamEvent{}
|
||||
switch msgType {
|
||||
case coderws.MessageText:
|
||||
@ -438,6 +510,9 @@ func runUpstreamToClient(
|
||||
return
|
||||
}
|
||||
wroteDownstream = true
|
||||
if afterWriteClient != nil {
|
||||
afterWriteClient()
|
||||
}
|
||||
if forwardedFrames != nil {
|
||||
forwardedFrames.Add(1)
|
||||
}
|
||||
|
||||
@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||
runClientToUpstream(
|
||||
context.Background(),
|
||||
newPassthroughTestFrameConn(nil, true),
|
||||
nil,
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
func() {},
|
||||
nil,
|
||||
@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||
}, true),
|
||||
nil,
|
||||
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
||||
func() {},
|
||||
nil,
|
||||
@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||
}, true),
|
||||
nil,
|
||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||
func() {},
|
||||
forwarded,
|
||||
@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
nil,
|
||||
@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
nil,
|
||||
@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
||||
&relayState{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
drop,
|
||||
nil,
|
||||
dropped,
|
||||
|
||||
@ -358,6 +358,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
statusCode,
|
||||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||||
)
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
}
|
||||
}
|
||||
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||||
}
|
||||
defer func() {
|
||||
@ -454,15 +461,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
cancel()
|
||||
},
|
||||
}
|
||||
upstreamFirstMessageSent := false
|
||||
firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage)
|
||||
cancelFirstWrite()
|
||||
if firstWriteErr != nil {
|
||||
return wrapOpenAIWSIngressTurnError(
|
||||
"write_upstream",
|
||||
fmt.Errorf("write first upstream websocket request: %w", firstWriteErr),
|
||||
false,
|
||||
)
|
||||
}
|
||||
upstreamFirstMessageSent = true
|
||||
|
||||
readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) {
|
||||
for {
|
||||
msgType, payload, readErr := conn.ReadFrame(readCtx)
|
||||
if readErr != nil {
|
||||
return msgType, payload, readErr
|
||||
}
|
||||
if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
return msgType, payload, nil
|
||||
}
|
||||
if writeErr := upstreamFrameConn.WriteFrame(readCtx, msgType, payload); writeErr != nil {
|
||||
return msgType, payload, writeErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||
Ctx: ctx,
|
||||
ClientConn: policyClientConn,
|
||||
UpstreamConn: upstreamFrameConn,
|
||||
FirstClientMessage: firstClientMessage,
|
||||
Options: openaiwsv2.RelayOptions{
|
||||
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||
FirstMessageType: coderws.MessageText,
|
||||
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||
FirstMessageType: coderws.MessageText,
|
||||
FirstMessageSent: upstreamFirstMessageSent,
|
||||
StartClientAfterFirstDownstream: true,
|
||||
ReadClientFrame: readNextClientFrame,
|
||||
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"usage_parse_failed event_type=%s usage_raw=%s",
|
||||
@ -505,6 +543,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
hooks.AfterTurn(turnNo, turnResult, nil)
|
||||
}
|
||||
},
|
||||
BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error {
|
||||
if msgType != coderws.MessageText || wroteDownstream {
|
||||
return nil
|
||||
}
|
||||
if eventType, _, _ := parseOpenAIWSEventEnvelope(payload); eventType != "error" {
|
||||
return nil
|
||||
}
|
||||
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(payload)
|
||||
if !isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
|
||||
return nil
|
||||
}
|
||||
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s",
|
||||
account.ID,
|
||||
truncateOpenAIWSLogValue(errCodeRaw, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen),
|
||||
truncateOpenAIWSLogValue(errMsgRaw, openAIWSLogValueMaxLen),
|
||||
)
|
||||
return &UpstreamFailoverError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
ResponseBody: append([]byte(nil), payload...),
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
}
|
||||
},
|
||||
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user