From 2bd3125d0fe23515fd42ee3c6651bbe38e47593c Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Thu, 28 May 2026 22:08:02 +0800 Subject: [PATCH] Preserve usage request context --- backend/internal/handler/gateway_handler.go | 7 +-- .../gateway_handler_chat_completions.go | 2 +- .../handler/gateway_handler_responses.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 2 +- .../handler/openai_chat_completions.go | 2 +- backend/internal/handler/openai_embeddings.go | 2 +- .../handler/openai_gateway_handler.go | 44 +++++++++++++--- .../openai_gateway_usage_context_test.go | 41 +++++++++++++++ backend/internal/handler/openai_images.go | 2 +- .../handler/usage_record_submit_task_test.go | 24 ++++----- .../server/middleware/client_request_id.go | 6 ++- .../middleware/client_request_id_test.go | 50 +++++++++++++++++++ 12 files changed, 154 insertions(+), 30 deletions(-) create mode 100644 backend/internal/handler/openai_gateway_usage_context_test.go create mode 100644 backend/internal/server/middleware/client_request_id_test.go diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4695a791..a6749191 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -510,7 +510,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, @@ -905,7 +905,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, @@ -2056,10 +2056,11 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger ) } -func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { +func (h *GatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) { if task == nil { return } + task = wrapUsageRecordTaskContext(parent, task) if h.usageRecordWorkerPool != nil { h.usageRecordWorkerPool.Submit(task) return diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index acbdc261..daf6e6ea 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -292,7 +292,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, QuotaPlatform: quotaPlatform, diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 6a083f31..f57b9989 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -267,7 +267,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, QuotaPlatform: quotaPlatform, diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 27ea4404..0b33ca3e 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -528,7 +528,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, QuotaPlatform: quotaPlatform, diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 17f0d47e..9f63ef1f 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -273,7 +273,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account) - h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go index bbb67044..b64ac41d 100644 --- a/backend/internal/handler/openai_embeddings.go +++ b/backend/internal/handler/openai_embeddings.go @@ -220,7 +220,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a51eee86..86503f30 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -12,6 +12,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -46,6 +47,31 @@ func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedM return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) } +func usageRecordContext(parent context.Context, base context.Context) context.Context { + if base == nil { + base = context.Background() + } + if parent == nil { + return base + } + if clientRequestID, _ := parent.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + base = context.WithValue(base, ctxkey.ClientRequestID, strings.TrimSpace(clientRequestID)) + } + if requestID, _ := parent.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + base = context.WithValue(base, ctxkey.RequestID, strings.TrimSpace(requestID)) + } + return base +} + +func wrapUsageRecordTaskContext(parent context.Context, task service.UsageRecordTask) service.UsageRecordTask { + if task == nil { + return nil + } + return func(ctx context.Context) { + task(usageRecordContext(parent, ctx)) + } +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -437,7 +463,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, @@ -821,7 +847,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, @@ -1424,7 +1450,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { + h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, @@ -1609,10 +1635,11 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) { } } -func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { +func (h *OpenAIGatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) { if task == nil { return } + task = wrapUsageRecordTaskContext(parent, task) if h.usageRecordWorkerPool != nil { h.usageRecordWorkerPool.Submit(task) return @@ -1631,18 +1658,19 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas task(ctx) } -func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) { +func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(parent context.Context, result *service.OpenAIForwardResult, task service.UsageRecordTask) { if result != nil && result.ImageCount > 0 { - h.submitMandatoryUsageRecordTask(task) + h.submitMandatoryUsageRecordTask(parent, task) return } - h.submitUsageRecordTask(task) + h.submitUsageRecordTask(parent, task) } -func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) { +func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(parent context.Context, task service.UsageRecordTask) { if task == nil { return } + task = wrapUsageRecordTaskContext(parent, task) if h.usageRecordWorkerPool != nil { if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped { return diff --git a/backend/internal/handler/openai_gateway_usage_context_test.go b/backend/internal/handler/openai_gateway_usage_context_test.go new file mode 100644 index 00000000..7091c9c0 --- /dev/null +++ b/backend/internal/handler/openai_gateway_usage_context_test.go @@ -0,0 +1,41 @@ +package handler + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestSubmitUsageRecordTaskCopiesRequestContext(t *testing.T) { + parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-request-123") + parent = context.WithValue(parent, ctxkey.RequestID, "request-456") + + var gotClientRequestID string + var gotRequestID string + h := &GatewayHandler{} + h.submitUsageRecordTask(parent, func(ctx context.Context) { + gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string) + gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string) + }) + + require.Equal(t, "client-request-123", gotClientRequestID) + require.Equal(t, "request-456", gotRequestID) +} + +func TestOpenAISubmitUsageRecordTaskCopiesRequestContext(t *testing.T) { + parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-request-123") + parent = context.WithValue(parent, ctxkey.RequestID, "openai-request-456") + + var gotClientRequestID string + var gotRequestID string + h := &OpenAIGatewayHandler{} + h.submitUsageRecordTask(parent, func(ctx context.Context) { + gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string) + gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string) + }) + + require.Equal(t, "openai-client-request-123", gotClientRequestID) + require.Equal(t, "openai-request-456", gotRequestID) +} diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index bbb08014..36339d4b 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -311,7 +311,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { if result != nil { upstreamModel = result.UpstreamModel } - h.submitMandatoryUsageRecordTask(func(ctx context.Context) { + h.submitMandatoryUsageRecordTask(c.Request.Context(), func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index e4c2837a..ebe5c3df 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -29,7 +29,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { h := &GatewayHandler{usageRecordWorkerPool: pool} done := make(chan struct{}) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { close(done) }) @@ -44,7 +44,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing. h := &GatewayHandler{} var called atomic.Bool - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { if _, ok := ctx.Deadline(); !ok { t.Fatal("expected deadline in fallback context") } @@ -57,7 +57,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing. func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &GatewayHandler{} require.NotPanics(t, func() { - h.submitUsageRecordTask(nil) + h.submitUsageRecordTask(context.Background(), nil) }) } @@ -66,12 +66,12 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *t var called atomic.Bool require.NotPanics(t, func() { - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { panic("usage task panic") }) }) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { called.Store(true) }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") @@ -82,7 +82,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} done := make(chan struct{}) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { close(done) }) @@ -97,7 +97,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te h := &OpenAIGatewayHandler{} var called atomic.Bool - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { if _, ok := ctx.Deadline(); !ok { t.Fatal("expected deadline in fallback context") } @@ -110,7 +110,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &OpenAIGatewayHandler{} require.NotPanics(t, func() { - h.submitUsageRecordTask(nil) + h.submitUsageRecordTask(context.Background(), nil) }) } @@ -119,12 +119,12 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere var called atomic.Bool require.NotPanics(t, func() { - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { panic("usage task panic") }) }) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitUsageRecordTask(context.Background(), func(ctx context.Context) { called.Store(true) }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") @@ -152,7 +152,7 @@ func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallb pool.Submit(func(ctx context.Context) {}) var called atomic.Bool - h.submitMandatoryUsageRecordTask(func(ctx context.Context) { + h.submitMandatoryUsageRecordTask(context.Background(), func(ctx context.Context) { called.Store(true) }) close(release) @@ -182,7 +182,7 @@ func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandator pool.Submit(func(ctx context.Context) {}) var called atomic.Bool - h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(context.Background(), &service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) { called.Store(true) }) close(release) diff --git a/backend/internal/server/middleware/client_request_id.go b/backend/internal/server/middleware/client_request_id.go index 6838d6af..5f886646 100644 --- a/backend/internal/server/middleware/client_request_id.go +++ b/backend/internal/server/middleware/client_request_id.go @@ -11,6 +11,8 @@ import ( "go.uber.org/zap" ) +const clientRequestIDHeader = "X-Client-Request-ID" + // ClientRequestID ensures every request has a unique client_request_id in request.Context(). // // This is used by the Ops monitoring module for end-to-end request correlation. @@ -21,12 +23,14 @@ func ClientRequestID() gin.HandlerFunc { return } - if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil { + if v, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(v) != "" { + c.Header(clientRequestIDHeader, strings.TrimSpace(v)) c.Next() return } id := uuid.New().String() + c.Header(clientRequestIDHeader, id) ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id) requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id))) ctx = logger.IntoContext(ctx, requestLogger) diff --git a/backend/internal/server/middleware/client_request_id_test.go b/backend/internal/server/middleware/client_request_id_test.go new file mode 100644 index 00000000..394c1612 --- /dev/null +++ b/backend/internal/server/middleware/client_request_id_test.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClientRequestIDGeneratesAndExposesID(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ClientRequestID()) + router.GET("/", func(c *gin.Context) { + value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + c.String(http.StatusOK, value) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.NotEmpty(t, w.Body.String()) + require.Equal(t, w.Body.String(), w.Header().Get(clientRequestIDHeader)) +} + +func TestClientRequestIDPreservesExistingContextID(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ClientRequestID()) + router.GET("/", func(c *gin.Context) { + value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + c.String(http.StatusOK, value) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "existing-client-request-id")) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, "existing-client-request-id", w.Body.String()) + require.Equal(t, "existing-client-request-id", w.Header().Get(clientRequestIDHeader)) +}