Merge pull request #2799 from siyuan-123/fix/ws-rate-limit-failover
修复 OpenAI WS 限额时不自动切换账号
This commit is contained in:
commit
2387cf9934
@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
|
|
||||||
var currentUserRelease func()
|
var currentUserRelease func()
|
||||||
var currentAccountRelease func()
|
var currentAccountRelease func()
|
||||||
releaseTurnSlots := func() {
|
releaseAccountSlot := func() {
|
||||||
if currentAccountRelease != nil {
|
if currentAccountRelease != nil {
|
||||||
currentAccountRelease()
|
currentAccountRelease()
|
||||||
currentAccountRelease = nil
|
currentAccountRelease = nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
releaseTurnSlots := func() {
|
||||||
|
releaseAccountSlot()
|
||||||
if currentUserRelease != nil {
|
if currentUserRelease != nil {
|
||||||
currentUserRelease()
|
currentUserRelease()
|
||||||
currentUserRelease = nil
|
currentUserRelease = nil
|
||||||
@ -1233,6 +1236,23 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||||
|
ensureUserSlotHeld := func() bool {
|
||||||
|
if currentUserRelease != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !userAcquired {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||||
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
firstMessage,
|
firstMessage,
|
||||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||||
)
|
)
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
ctx,
|
switchCount := 0
|
||||||
apiKey.GroupID,
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
previousResponseID,
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
sessionHash,
|
|
||||||
reqModel,
|
|
||||||
nil,
|
|
||||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if selection == nil || selection.Account == nil {
|
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
account := selection.Account
|
for {
|
||||||
accountMaxConcurrency := account.Concurrency
|
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
|
||||||
}
|
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
|
||||||
if !selection.Acquired {
|
|
||||||
if selection.WaitPlan == nil {
|
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
|
||||||
ctx,
|
ctx,
|
||||||
account.ID,
|
apiKey.GroupID,
|
||||||
selection.WaitPlan.MaxConcurrency,
|
previousResponseID,
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||||
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Warn("openai.websocket_account_select_failed",
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
zap.Error(err),
|
||||||
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
|
)
|
||||||
|
if lastFailoverErr != nil {
|
||||||
|
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||||
|
} else {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !fastAcquired {
|
if selection == nil || selection.Account == nil {
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
if lastFailoverErr != nil {
|
||||||
|
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||||
|
} else {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accountReleaseFunc = fastReleaseFunc
|
|
||||||
}
|
|
||||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
|
||||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
|
||||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
account := selection.Account
|
||||||
if err != nil {
|
accountMaxConcurrency := account.Concurrency
|
||||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||||
return
|
}
|
||||||
}
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
|
if !selection.Acquired {
|
||||||
reqLog.Debug("openai.websocket_account_selected",
|
if selection.WaitPlan == nil {
|
||||||
zap.Int64("account_id", account.ID),
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||||
zap.String("account_name", account.Name),
|
|
||||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
|
||||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
|
||||||
)
|
|
||||||
|
|
||||||
hooks := &service.OpenAIWSIngressHooks{
|
|
||||||
InitialRequestModel: reqModel,
|
|
||||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
|
||||||
if turn == 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !gjson.ValidBytes(payload) {
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
|
||||||
}
|
|
||||||
model := strings.TrimSpace(originalModel)
|
|
||||||
if model == "" {
|
|
||||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
|
||||||
}
|
|
||||||
if model == "" {
|
|
||||||
model = reqModel
|
|
||||||
}
|
|
||||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
|
||||||
writeContentModerationWSError(ctx, wsConn, decision)
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
BeforeTurn: func(turn int) error {
|
|
||||||
if turn == 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
|
||||||
releaseTurnSlots()
|
|
||||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
|
||||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
|
||||||
}
|
|
||||||
if !userAcquired {
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
|
||||||
}
|
|
||||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
|
||||||
if err != nil {
|
|
||||||
if userReleaseFunc != nil {
|
|
||||||
userReleaseFunc()
|
|
||||||
}
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
|
||||||
}
|
|
||||||
if !accountAcquired {
|
|
||||||
if userReleaseFunc != nil {
|
|
||||||
userReleaseFunc()
|
|
||||||
}
|
|
||||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
|
||||||
}
|
|
||||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
|
||||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
|
||||||
releaseTurnSlots()
|
|
||||||
if turnErr != nil {
|
|
||||||
if result == nil || result.ImageCount <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("image_count", result.ImageCount),
|
|
||||||
zap.Error(turnErr),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if result == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if account.Type == service.AccountTypeOAuth {
|
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
ctx,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
if !fastAcquired {
|
||||||
inboundEndpoint := GetInboundEndpoint(c)
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
return
|
||||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
}
|
||||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
accountReleaseFunc = fastReleaseFunc
|
||||||
Result: result,
|
}
|
||||||
APIKey: apiKey,
|
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||||
User: apiKey.User,
|
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
Account: account,
|
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
Subscription: subscription,
|
}
|
||||||
InboundEndpoint: inboundEndpoint,
|
|
||||||
UpstreamEndpoint: upstreamEndpoint,
|
|
||||||
UserAgent: userAgent,
|
|
||||||
IPAddress: clientIP,
|
|
||||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
|
||||||
APIKeyService: h.apiKeyService,
|
|
||||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
|
||||||
}); err != nil {
|
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.String("request_id", result.RequestID),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用渠道模型映射到 WebSocket 首条消息
|
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||||
wsFirstMessage := firstMessage
|
if err != nil {
|
||||||
if channelMappingWS.Mapped {
|
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
|
||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
|
||||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
|
||||||
reqLog.Warn("openai.websocket_proxy_failed",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Error(err),
|
|
||||||
zap.String("close_status", closeStatus),
|
|
||||||
zap.String("close_reason", closeReason),
|
|
||||||
)
|
|
||||||
var closeErr *service.OpenAIWSClientCloseError
|
|
||||||
if errors.As(err, &closeErr) {
|
|
||||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
|
||||||
|
reqLog.Debug("openai.websocket_account_selected",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("account_name", account.Name),
|
||||||
|
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||||
|
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks := &service.OpenAIWSIngressHooks{
|
||||||
|
InitialRequestModel: reqModel,
|
||||||
|
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||||
|
if turn == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(payload) {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||||
|
}
|
||||||
|
model := strings.TrimSpace(originalModel)
|
||||||
|
if model == "" {
|
||||||
|
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
model = reqModel
|
||||||
|
}
|
||||||
|
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||||
|
writeContentModerationWSError(ctx, wsConn, decision)
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
BeforeTurn: func(turn int) error {
|
||||||
|
if turn == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||||
|
releaseTurnSlots()
|
||||||
|
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||||
|
}
|
||||||
|
if !userAcquired {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||||
|
}
|
||||||
|
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||||
|
if err != nil {
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
userReleaseFunc()
|
||||||
|
}
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||||
|
}
|
||||||
|
if !accountAcquired {
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
userReleaseFunc()
|
||||||
|
}
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||||
|
}
|
||||||
|
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||||
|
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||||
|
releaseTurnSlots()
|
||||||
|
if turnErr != nil {
|
||||||
|
if result == nil || result.ImageCount <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("image_count", result.ImageCount),
|
||||||
|
zap.Error(turnErr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if account.Type == service.AccountTypeOAuth {
|
||||||
|
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: inboundEndpoint,
|
||||||
|
UpstreamEndpoint: upstreamEndpoint,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||||
|
}); err != nil {
|
||||||
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("request_id", result.RequestID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用渠道模型映射到 WebSocket 首条消息
|
||||||
|
wsFirstMessage := firstMessage
|
||||||
|
if channelMappingWS.Mapped {
|
||||||
|
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
releaseAccountSlot()
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||||
|
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
|
reqLog.Warn("openai.websocket_upstream_failover_switching",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
)
|
||||||
|
if !ensureUserSlotHeld() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||||
|
reqLog.Warn("openai.websocket_proxy_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("close_status", closeStatus),
|
||||||
|
zap.String("close_reason", closeReason),
|
||||||
|
)
|
||||||
|
var closeErr *service.OpenAIWSClientCloseError
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||||
@ -1800,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
|||||||
_ = conn.CloseNow()
|
_ = conn.CloseNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
|
||||||
|
if failoverErr == nil {
|
||||||
|
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch failoverErr.StatusCode {
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later")
|
||||||
|
case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||||
|
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable")
|
||||||
|
case http.StatusUnauthorized, http.StatusForbidden:
|
||||||
|
closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed")
|
||||||
|
default:
|
||||||
|
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||||
if conn == nil || decision == nil {
|
if conn == nil || decision == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@ -1075,6 +1075,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
|
|||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIWSFailoverHandlerAccountRepoStub struct {
|
||||||
|
service.AccountRepository
|
||||||
|
accounts []service.Account
|
||||||
|
rateLimitedIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
out := make([]service.Account, 0, len(s.accounts))
|
||||||
|
for _, account := range s.accounts {
|
||||||
|
if account.Platform == platform && account.IsSchedulable() {
|
||||||
|
out = append(out, account)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||||
|
return s.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return s.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
for _, account := range s.accounts {
|
||||||
|
if account.ID == id {
|
||||||
|
acc := account
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
|
s.rateLimitedIDs = append(s.rateLimitedIDs, id)
|
||||||
|
for i := range s.accounts {
|
||||||
|
if s.accounts[i].ID == id {
|
||||||
|
reset := resetAt
|
||||||
|
s.accounts[i].RateLimitResetAt = &reset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||||
service.UsageLogRepository
|
service.UsageLogRepository
|
||||||
created chan *service.UsageLog
|
created chan *service.UsageLog
|
||||||
@ -1107,6 +1153,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
firstHitCh := make(chan []byte, 1)
|
||||||
|
secondHitCh := make(chan []byte, 1)
|
||||||
|
|
||||||
|
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.CloseNow() }()
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
_, payload, readErr := conn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
if readErr == nil {
|
||||||
|
firstHitCh <- payload
|
||||||
|
}
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
|
||||||
|
cancelWrite()
|
||||||
|
}))
|
||||||
|
defer firstUpstream.Close()
|
||||||
|
|
||||||
|
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.CloseNow() }()
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
_, payload, readErr := conn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
if readErr == nil {
|
||||||
|
secondHitCh <- payload
|
||||||
|
}
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
|
||||||
|
cancelWrite()
|
||||||
|
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
}))
|
||||||
|
defer secondUpstream.Close()
|
||||||
|
|
||||||
|
groupID := int64(4202)
|
||||||
|
accounts := []service.Account{
|
||||||
|
{
|
||||||
|
ID: 9902,
|
||||||
|
Name: "openai-ws-rate-limited",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-first",
|
||||||
|
"base_url": firstUpstream.URL,
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 9903,
|
||||||
|
Name: "openai-ws-healthy",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 2,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-second",
|
||||||
|
"base_url": secondUpstream.URL,
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.RunMode = config.RunModeSimple
|
||||||
|
cfg.Default.RateMultiplier = 1
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.MaxAccountSwitches = 3
|
||||||
|
|
||||||
|
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
|
||||||
|
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
|
||||||
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||||
|
gatewaySvc := service.NewOpenAIGatewayService(
|
||||||
|
accountRepo,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
service.NewBillingService(cfg, nil),
|
||||||
|
rateLimitSvc,
|
||||||
|
billingCacheSvc,
|
||||||
|
nil,
|
||||||
|
&service.DeferredService{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
cache := &concurrencyCacheMock{
|
||||||
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := &OpenAIGatewayHandler{
|
||||||
|
gatewayService: gatewaySvc,
|
||||||
|
billingCacheService: billingCacheSvc,
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||||
|
maxAccountSwitches: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 1802,
|
||||||
|
GroupID: &groupID,
|
||||||
|
User: &service.User{ID: 1702, Status: service.StatusActive},
|
||||||
|
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
|
||||||
|
}
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||||
|
handlerServer := httptest.NewServer(router)
|
||||||
|
defer handlerServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(
|
||||||
|
dialCtx,
|
||||||
|
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||||
|
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
|
||||||
|
)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = clientConn.CloseNow() }()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
_, event, err := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||||
|
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-firstHitCh:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("等待第一个上游收到首帧超时")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-secondHitCh:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("等待第二个上游收到重放首帧超时")
|
||||||
|
}
|
||||||
|
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|||||||
@ -2782,6 +2782,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
var dialErr *openAIWSDialError
|
var dialErr *openAIWSDialError
|
||||||
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
|
||||||
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
ResponseHeaders: cloneHeader(dialErr.ResponseHeaders),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
|
||||||
return nil, NewOpenAIWSClientCloseError(
|
return nil, NewOpenAIWSClientCloseError(
|
||||||
@ -2977,6 +2981,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
|
||||||
|
lease.MarkBroken()
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
ResponseBody: append([]byte(nil), upstreamMessage...),
|
||||||
|
ResponseHeaders: cloneHeader(lease.HandshakeHeaders()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
isTokenEvent := isOpenAIWSTokenEvent(eventType)
|
isTokenEvent := isOpenAIWSTokenEvent(eventType)
|
||||||
if isTokenEvent {
|
if isTokenEvent {
|
||||||
|
|||||||
@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL
|
|||||||
select {
|
select {
|
||||||
case serverErr := <-serverErrCh:
|
case serverErr := <-serverErrCh:
|
||||||
require.Error(t, serverErr)
|
require.Error(t, serverErr)
|
||||||
|
var failoverErr *UpstreamFailoverError
|
||||||
|
require.ErrorAs(t, serverErr, &failoverErr)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode)
|
||||||
require.Len(t, repo.rateLimitCalls, 1)
|
require.Len(t, repo.rateLimitCalls, 1)
|
||||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
|
|||||||
@ -55,14 +55,18 @@ type RelayExit struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RelayOptions struct {
|
type RelayOptions struct {
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
IdleTimeout time.Duration
|
IdleTimeout time.Duration
|
||||||
UpstreamDrainTimeout time.Duration
|
UpstreamDrainTimeout time.Duration
|
||||||
FirstMessageType coderws.MessageType
|
FirstMessageType coderws.MessageType
|
||||||
OnUsageParseFailure func(eventType string, usageRaw string)
|
FirstMessageSent bool
|
||||||
OnTurnComplete func(turn RelayTurnResult)
|
StartClientAfterFirstDownstream bool
|
||||||
OnTrace func(event RelayTraceEvent)
|
OnUsageParseFailure func(eventType string, usageRaw string)
|
||||||
Now func() time.Time
|
OnTurnComplete func(turn RelayTurnResult)
|
||||||
|
BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error
|
||||||
|
ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error)
|
||||||
|
OnTrace func(event RelayTraceEvent)
|
||||||
|
Now func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelayTraceEvent struct {
|
type RelayTraceEvent struct {
|
||||||
@ -170,29 +174,47 @@ func Relay(
|
|||||||
MessageType: relayMessageTypeString(firstMessageType),
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
if options.FirstMessageSent {
|
||||||
result.Duration = nowFn().Sub(startAt)
|
|
||||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
Stage: "write_first_message_failed",
|
Stage: "write_first_message_skipped",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
||||||
|
result.Duration = nowFn().Sub(startAt)
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_failed",
|
||||||
|
Direction: "client_to_upstream",
|
||||||
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
|
PayloadBytes: len(firstClientMessage),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
||||||
|
}
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "write_first_message_ok",
|
||||||
Direction: "client_to_upstream",
|
Direction: "client_to_upstream",
|
||||||
MessageType: relayMessageTypeString(firstMessageType),
|
MessageType: relayMessageTypeString(firstMessageType),
|
||||||
PayloadBytes: len(firstClientMessage),
|
PayloadBytes: len(firstClientMessage),
|
||||||
Error: err.Error(),
|
|
||||||
})
|
})
|
||||||
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
|
||||||
}
|
}
|
||||||
clientToUpstreamFrames.Add(1)
|
clientToUpstreamFrames.Add(1)
|
||||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
|
||||||
Stage: "write_first_message_ok",
|
|
||||||
Direction: "client_to_upstream",
|
|
||||||
MessageType: relayMessageTypeString(firstMessageType),
|
|
||||||
PayloadBytes: len(firstClientMessage),
|
|
||||||
})
|
|
||||||
markActivity()
|
markActivity()
|
||||||
|
|
||||||
exitCh := make(chan relayExitSignal, 3)
|
exitCh := make(chan relayExitSignal, 3)
|
||||||
dropDownstreamWrites := atomic.Bool{}
|
dropDownstreamWrites := atomic.Bool{}
|
||||||
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
clientReaderStarted := atomic.Bool{}
|
||||||
|
startClientReader := func() {
|
||||||
|
if !clientReaderStarted.CompareAndSwap(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
||||||
|
}
|
||||||
|
if !options.StartClientAfterFirstDownstream {
|
||||||
|
startClientReader()
|
||||||
|
}
|
||||||
go runUpstreamToClient(
|
go runUpstreamToClient(
|
||||||
relayCtx,
|
relayCtx,
|
||||||
upstreamConn,
|
upstreamConn,
|
||||||
@ -202,6 +224,12 @@ func Relay(
|
|||||||
state,
|
state,
|
||||||
options.OnUsageParseFailure,
|
options.OnUsageParseFailure,
|
||||||
options.OnTurnComplete,
|
options.OnTurnComplete,
|
||||||
|
options.BeforeWriteClient,
|
||||||
|
func() {
|
||||||
|
if options.StartClientAfterFirstDownstream {
|
||||||
|
startClientReader()
|
||||||
|
}
|
||||||
|
},
|
||||||
&dropDownstreamWrites,
|
&dropDownstreamWrites,
|
||||||
upstreamToClientFrames,
|
upstreamToClientFrames,
|
||||||
droppedDownstreamFrames,
|
droppedDownstreamFrames,
|
||||||
@ -230,7 +258,9 @@ func Relay(
|
|||||||
} else {
|
} else {
|
||||||
relayCancel()
|
relayCancel()
|
||||||
_ = upstreamConn.Close()
|
_ = upstreamConn.Close()
|
||||||
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
if clientReaderStarted.Load() {
|
||||||
|
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if hasSecondExit {
|
if hasSecondExit {
|
||||||
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
||||||
@ -250,6 +280,14 @@ func Relay(
|
|||||||
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
||||||
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
||||||
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
||||||
|
if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_client_closed",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
if firstExit.stage == "read_client" && firstExit.graceful {
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
||||||
stage := "client_disconnected"
|
stage := "client_disconnected"
|
||||||
exitErr := firstExit.err
|
exitErr := firstExit.err
|
||||||
@ -310,6 +348,14 @@ func Relay(
|
|||||||
WroteDownstream: combinedWroteDownstream,
|
WroteDownstream: combinedWroteDownstream,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if options.FirstMessageSent {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "relay_client_closed",
|
||||||
|
Graceful: true,
|
||||||
|
WroteDownstream: combinedWroteDownstream,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
Stage: "relay_complete",
|
Stage: "relay_complete",
|
||||||
Graceful: true,
|
Graceful: true,
|
||||||
@ -322,14 +368,20 @@ func Relay(
|
|||||||
func runClientToUpstream(
|
func runClientToUpstream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
clientConn FrameConn,
|
clientConn FrameConn,
|
||||||
|
readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error),
|
||||||
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
||||||
markActivity func(),
|
markActivity func(),
|
||||||
forwardedFrames *atomic.Int64,
|
forwardedFrames *atomic.Int64,
|
||||||
onTrace func(event RelayTraceEvent),
|
onTrace func(event RelayTraceEvent),
|
||||||
exitCh chan<- relayExitSignal,
|
exitCh chan<- relayExitSignal,
|
||||||
) {
|
) {
|
||||||
|
if readClientFrame == nil {
|
||||||
|
readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) {
|
||||||
|
return conn.ReadFrame(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
msgType, payload, err := clientConn.ReadFrame(ctx)
|
msgType, payload, err := readClientFrame(ctx, clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
emitRelayTrace(onTrace, RelayTraceEvent{
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
Stage: "read_client_failed",
|
Stage: "read_client_failed",
|
||||||
@ -368,6 +420,8 @@ func runUpstreamToClient(
|
|||||||
state *relayState,
|
state *relayState,
|
||||||
onUsageParseFailure func(eventType string, usageRaw string),
|
onUsageParseFailure func(eventType string, usageRaw string),
|
||||||
onTurnComplete func(turn RelayTurnResult),
|
onTurnComplete func(turn RelayTurnResult),
|
||||||
|
beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error,
|
||||||
|
afterWriteClient func(),
|
||||||
dropDownstreamWrites *atomic.Bool,
|
dropDownstreamWrites *atomic.Bool,
|
||||||
forwardedFrames *atomic.Int64,
|
forwardedFrames *atomic.Int64,
|
||||||
droppedFrames *atomic.Int64,
|
droppedFrames *atomic.Int64,
|
||||||
@ -395,6 +449,24 @@ func runUpstreamToClient(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
markActivity()
|
markActivity()
|
||||||
|
if beforeWriteClient != nil {
|
||||||
|
if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil {
|
||||||
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
||||||
|
Stage: "upstream_message_rejected",
|
||||||
|
Direction: "upstream_to_client",
|
||||||
|
MessageType: relayMessageTypeString(msgType),
|
||||||
|
PayloadBytes: len(payload),
|
||||||
|
WroteDownstream: wroteDownstream,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
exitCh <- relayExitSignal{
|
||||||
|
stage: "upstream_message",
|
||||||
|
err: err,
|
||||||
|
wroteDownstream: wroteDownstream,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
observedEvent := observedUpstreamEvent{}
|
observedEvent := observedUpstreamEvent{}
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case coderws.MessageText:
|
case coderws.MessageText:
|
||||||
@ -438,6 +510,9 @@ func runUpstreamToClient(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
wroteDownstream = true
|
wroteDownstream = true
|
||||||
|
if afterWriteClient != nil {
|
||||||
|
afterWriteClient()
|
||||||
|
}
|
||||||
if forwardedFrames != nil {
|
if forwardedFrames != nil {
|
||||||
forwardedFrames.Add(1)
|
forwardedFrames.Add(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
|||||||
runClientToUpstream(
|
runClientToUpstream(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
newPassthroughTestFrameConn(nil, true),
|
newPassthroughTestFrameConn(nil, true),
|
||||||
|
nil,
|
||||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
func() {},
|
func() {},
|
||||||
nil,
|
nil,
|
||||||
@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
|||||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
}, true),
|
}, true),
|
||||||
|
nil,
|
||||||
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
|
||||||
func() {},
|
func() {},
|
||||||
nil,
|
nil,
|
||||||
@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
|
|||||||
newPassthroughTestFrameConn([]passthroughTestFrame{
|
newPassthroughTestFrameConn([]passthroughTestFrame{
|
||||||
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
|
||||||
}, true),
|
}, true),
|
||||||
|
nil,
|
||||||
func(_ coderws.MessageType, _ []byte) error { return nil },
|
func(_ coderws.MessageType, _ []byte) error { return nil },
|
||||||
func() {},
|
func() {},
|
||||||
forwarded,
|
forwarded,
|
||||||
@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
|||||||
&relayState{},
|
&relayState{},
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
drop,
|
drop,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
|||||||
&relayState{},
|
&relayState{},
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
drop,
|
drop,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
|
|||||||
&relayState{},
|
&relayState{},
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
drop,
|
drop,
|
||||||
nil,
|
nil,
|
||||||
dropped,
|
dropped,
|
||||||
|
|||||||
@ -359,6 +359,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
statusCode,
|
statusCode,
|
||||||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||||||
)
|
)
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
|
}
|
||||||
|
}
|
||||||
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -456,15 +463,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
cancel()
|
cancel()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
upstreamFirstMessageSent := false
|
||||||
|
firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||||
|
firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage)
|
||||||
|
cancelFirstWrite()
|
||||||
|
if firstWriteErr != nil {
|
||||||
|
return wrapOpenAIWSIngressTurnError(
|
||||||
|
"write_upstream",
|
||||||
|
fmt.Errorf("write first upstream websocket request: %w", firstWriteErr),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
upstreamFirstMessageSent = true
|
||||||
|
|
||||||
|
readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) {
|
||||||
|
for {
|
||||||
|
msgType, payload, readErr := conn.ReadFrame(readCtx)
|
||||||
|
if readErr != nil {
|
||||||
|
return msgType, payload, readErr
|
||||||
|
}
|
||||||
|
if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||||
|
return msgType, payload, nil
|
||||||
|
}
|
||||||
|
if writeErr := upstreamFrameConn.WriteFrame(readCtx, msgType, payload); writeErr != nil {
|
||||||
|
return msgType, payload, writeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
ClientConn: policyClientConn,
|
ClientConn: policyClientConn,
|
||||||
UpstreamConn: upstreamFrameConn,
|
UpstreamConn: upstreamFrameConn,
|
||||||
FirstClientMessage: firstClientMessage,
|
FirstClientMessage: firstClientMessage,
|
||||||
Options: openaiwsv2.RelayOptions{
|
Options: openaiwsv2.RelayOptions{
|
||||||
WriteTimeout: s.openAIWSWriteTimeout(),
|
WriteTimeout: s.openAIWSWriteTimeout(),
|
||||||
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||||||
FirstMessageType: coderws.MessageText,
|
FirstMessageType: coderws.MessageText,
|
||||||
|
FirstMessageSent: upstreamFirstMessageSent,
|
||||||
|
StartClientAfterFirstDownstream: true,
|
||||||
|
ReadClientFrame: readNextClientFrame,
|
||||||
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||||||
logOpenAIWSV2Passthrough(
|
logOpenAIWSV2Passthrough(
|
||||||
"usage_parse_failed event_type=%s usage_raw=%s",
|
"usage_parse_failed event_type=%s usage_raw=%s",
|
||||||
@ -507,6 +545,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
hooks.AfterTurn(turnNo, turnResult, nil)
|
hooks.AfterTurn(turnNo, turnResult, nil)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error {
|
||||||
|
if msgType != coderws.MessageText || wroteDownstream {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if eventType, _, _ := parseOpenAIWSEventEnvelope(payload); eventType != "error" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(payload)
|
||||||
|
if !isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, errCodeRaw, errTypeRaw, errMsgRaw)
|
||||||
|
logOpenAIWSV2Passthrough(
|
||||||
|
"relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s",
|
||||||
|
account.ID,
|
||||||
|
truncateOpenAIWSLogValue(errCodeRaw, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen),
|
||||||
|
truncateOpenAIWSLogValue(errMsgRaw, openAIWSLogValueMaxLen),
|
||||||
|
)
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
ResponseBody: append([]byte(nil), payload...),
|
||||||
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
|
}
|
||||||
|
},
|
||||||
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||||||
logOpenAIWSV2Passthrough(
|
logOpenAIWSV2Passthrough(
|
||||||
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user