diff --git a/backend/internal/handler/concurrency_error_response.go b/backend/internal/handler/concurrency_error_response.go new file mode 100644 index 00000000..52abf735 --- /dev/null +++ b/backend/internal/handler/concurrency_error_response.go @@ -0,0 +1,27 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +const statusClientClosedRequest = 499 + +func concurrencyErrorResponse(err error, slotType string) (int, string, string) { + var concurrencyErr *ConcurrencyError + if errors.As(err, &concurrencyErr) { + if concurrencyErr.SlotType != "" { + slotType = concurrencyErr.SlotType + } + return http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType) + } + + if errors.Is(err, context.Canceled) { + return statusClientClosedRequest, "api_error", "context canceled" + } + + return http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable, please retry later" +} diff --git a/backend/internal/handler/concurrency_error_response_test.go b/backend/internal/handler/concurrency_error_response_test.go new file mode 100644 index 00000000..a2e6b9ab --- /dev/null +++ b/backend/internal/handler/concurrency_error_response_test.go @@ -0,0 +1,63 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConcurrencyErrorResponse(t *testing.T) { + tests := []struct { + name string + err error + slotType string + wantStatus int + wantType string + wantMessage string + }{ + { + name: "true concurrency timeout remains rate limit", + err: &ConcurrencyError{SlotType: "account", IsTimeout: true}, + slotType: "user", + wantStatus: http.StatusTooManyRequests, + wantType: "rate_limit_error", + wantMessage: "Concurrency limit exceeded for account, please retry later", + }, + { + name: "client cancellation is not classified as concurrency limit", + err: context.Canceled, + slotType: "user", + wantStatus: statusClientClosedRequest, + wantType: "api_error", + wantMessage: "context canceled", + }, + { + name: "deadline exceeded is service unavailable", + err: context.DeadlineExceeded, + slotType: "user", + wantStatus: http.StatusServiceUnavailable, + wantType: "api_error", + wantMessage: "Service temporarily unavailable, please retry later", + }, + { + name: "redis acquire error is service unavailable", + err: errors.New("redis unavailable"), + slotType: "user", + wantStatus: http.StatusServiceUnavailable, + wantType: "api_error", + wantMessage: "Service temporarily unavailable, please retry later", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status, errType, message := concurrencyErrorResponse(tt.err, tt.slotType) + require.Equal(t, tt.wantStatus, status) + require.Equal(t, tt.wantType, errType) + require.Equal(t, tt.wantMessage, message) + }) + } +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4695a791..a24611f9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -1471,10 +1471,10 @@ func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, su return min } -// handleConcurrencyError handles concurrency-related errors with proper 429 response +// handleConcurrencyError handles concurrency-related acquire errors. func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", - fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) + status, errType, message := concurrencyErrorResponse(err, slotType) + h.handleStreamingAwareError(c, status, errType, message, streamStarted) } func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 09e6c09b..e4897502 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -336,6 +336,9 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType for { select { case <-ctx.Done(): + if parentErr := c.Request.Context().Err(); parentErr != nil { + return nil, parentErr + } return nil, &ConcurrencyError{ SlotType: slotType, IsTimeout: true, diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 4a677199..d57c396c 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -280,6 +280,25 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { }) } +func TestWaitForSlotWithPingTimeout_ParentContextCanceled(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + reqCtx, cancel := context.WithCancel(c.Request.Context()) + c.Request = c.Request.WithContext(reqCtx) + cancel() + + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true) + require.Nil(t, release) + require.ErrorIs(t, err, context.Canceled) + var cErr *ConcurrencyError + require.False(t, errors.As(err, &cErr)) +} + func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { errCache := &helperConcurrencyCacheStubWithError{ err: errors.New("redis unavailable"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 1d661748..979aaa1c 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1688,10 +1688,10 @@ func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, stream return nil, false } -// handleConcurrencyError handles concurrency-related errors with proper 429 response +// handleConcurrencyError handles concurrency-related acquire errors. func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", - fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) + status, errType, message := concurrencyErrorResponse(err, slotType) + h.handleStreamingAwareError(c, status, errType, message, streamStarted) } func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {