Preserve usage request context

This commit is contained in:
Wey Gu 2026-05-28 22:08:02 +08:00
parent 8c1a07852c
commit 2bd3125d0f
12 changed files with 154 additions and 30 deletions

View File

@ -510,7 +510,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
@ -905,7 +905,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
@ -2056,10 +2056,11 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
)
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
func (h *GatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return

View File

@ -292,7 +292,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -267,7 +267,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -528,7 +528,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -273,7 +273,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -220,7 +220,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -12,6 +12,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -46,6 +47,31 @@ func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedM
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
func usageRecordContext(parent context.Context, base context.Context) context.Context {
if base == nil {
base = context.Background()
}
if parent == nil {
return base
}
if clientRequestID, _ := parent.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
base = context.WithValue(base, ctxkey.ClientRequestID, strings.TrimSpace(clientRequestID))
}
if requestID, _ := parent.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
base = context.WithValue(base, ctxkey.RequestID, strings.TrimSpace(requestID))
}
return base
}
func wrapUsageRecordTaskContext(parent context.Context, task service.UsageRecordTask) service.UsageRecordTask {
if task == nil {
return nil
}
return func(ctx context.Context) {
task(usageRecordContext(parent, ctx))
}
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@ -437,7 +463,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -821,7 +847,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -1424,7 +1450,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -1609,10 +1635,11 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
@ -1631,18 +1658,19 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
task(ctx)
}
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(parent context.Context, result *service.OpenAIForwardResult, task service.UsageRecordTask) {
if result != nil && result.ImageCount > 0 {
h.submitMandatoryUsageRecordTask(task)
h.submitMandatoryUsageRecordTask(parent, task)
return
}
h.submitUsageRecordTask(task)
h.submitUsageRecordTask(parent, task)
}
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
return

View File

@ -0,0 +1,41 @@
package handler
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-request-123")
parent = context.WithValue(parent, ctxkey.RequestID, "request-456")
var gotClientRequestID string
var gotRequestID string
h := &GatewayHandler{}
h.submitUsageRecordTask(parent, func(ctx context.Context) {
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
})
require.Equal(t, "client-request-123", gotClientRequestID)
require.Equal(t, "request-456", gotRequestID)
}
func TestOpenAISubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-request-123")
parent = context.WithValue(parent, ctxkey.RequestID, "openai-request-456")
var gotClientRequestID string
var gotRequestID string
h := &OpenAIGatewayHandler{}
h.submitUsageRecordTask(parent, func(ctx context.Context) {
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
})
require.Equal(t, "openai-client-request-123", gotClientRequestID)
require.Equal(t, "openai-request-456", gotRequestID)
}

View File

@ -311,7 +311,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if result != nil {
upstreamModel = result.UpstreamModel
}
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
h.submitMandatoryUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -29,7 +29,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
close(done)
})
@ -44,7 +44,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
h := &GatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
@ -57,7 +57,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
h.submitUsageRecordTask(context.Background(), nil)
})
}
@ -66,12 +66,12 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *t
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
@ -82,7 +82,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
close(done)
})
@ -97,7 +97,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
h := &OpenAIGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
@ -110,7 +110,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
h.submitUsageRecordTask(context.Background(), nil)
})
}
@ -119,12 +119,12 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
@ -152,7 +152,7 @@ func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallb
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
h.submitMandatoryUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
close(release)
@ -182,7 +182,7 @@ func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandator
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(context.Background(), &service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
called.Store(true)
})
close(release)

View File

@ -11,6 +11,8 @@ import (
"go.uber.org/zap"
)
const clientRequestIDHeader = "X-Client-Request-ID"
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
//
// This is used by the Ops monitoring module for end-to-end request correlation.
@ -21,12 +23,14 @@ func ClientRequestID() gin.HandlerFunc {
return
}
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
if v, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(v) != "" {
c.Header(clientRequestIDHeader, strings.TrimSpace(v))
c.Next()
return
}
id := uuid.New().String()
c.Header(clientRequestIDHeader, id)
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
ctx = logger.IntoContext(ctx, requestLogger)

View File

@ -0,0 +1,50 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestClientRequestIDGeneratesAndExposesID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ClientRequestID())
router.GET("/", func(c *gin.Context) {
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
c.String(http.StatusOK, value)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.NotEmpty(t, w.Body.String())
require.Equal(t, w.Body.String(), w.Header().Get(clientRequestIDHeader))
}
func TestClientRequestIDPreservesExistingContextID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ClientRequestID())
router.GET("/", func(c *gin.Context) {
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
c.String(http.StatusOK, value)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "existing-client-request-id"))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "existing-client-request-id", w.Body.String())
require.Equal(t, "existing-client-request-id", w.Header().Get(clientRequestIDHeader))
}