Merge pull request #2865 from wey-gu/feat/usage-request-context
fix(gateway): preserve usage request context
This commit is contained in:
commit
69e7c4db30
@ -510,7 +510,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
@ -905,7 +905,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
@ -2056,10 +2056,11 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
|
||||
)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
func (h *GatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
task = wrapUsageRecordTaskContext(parent, task)
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
|
||||
@ -292,7 +292,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
|
||||
@ -267,7 +267,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
|
||||
@ -528,7 +528,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
|
||||
@ -274,7 +274,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
||||
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
|
||||
@ -214,7 +214,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -46,6 +47,31 @@ func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedM
|
||||
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
|
||||
}
|
||||
|
||||
func usageRecordContext(parent context.Context, base context.Context) context.Context {
|
||||
if base == nil {
|
||||
base = context.Background()
|
||||
}
|
||||
if parent == nil {
|
||||
return base
|
||||
}
|
||||
if clientRequestID, _ := parent.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
base = context.WithValue(base, ctxkey.ClientRequestID, strings.TrimSpace(clientRequestID))
|
||||
}
|
||||
if requestID, _ := parent.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
base = context.WithValue(base, ctxkey.RequestID, strings.TrimSpace(requestID))
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func wrapUsageRecordTaskContext(parent context.Context, task service.UsageRecordTask) service.UsageRecordTask {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ctx context.Context) {
|
||||
task(usageRecordContext(parent, ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
@ -438,7 +464,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
@ -823,7 +849,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
@ -1427,7 +1453,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
@ -1612,10 +1638,11 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
task = wrapUsageRecordTaskContext(parent, task)
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
@ -1634,18 +1661,19 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(parent context.Context, result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
h.submitMandatoryUsageRecordTask(task)
|
||||
h.submitMandatoryUsageRecordTask(parent, task)
|
||||
return
|
||||
}
|
||||
h.submitUsageRecordTask(task)
|
||||
h.submitUsageRecordTask(parent, task)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
|
||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
task = wrapUsageRecordTaskContext(parent, task)
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
||||
return
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
|
||||
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-request-123")
|
||||
parent = context.WithValue(parent, ctxkey.RequestID, "request-456")
|
||||
|
||||
var gotClientRequestID string
|
||||
var gotRequestID string
|
||||
h := &GatewayHandler{}
|
||||
h.submitUsageRecordTask(parent, func(ctx context.Context) {
|
||||
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
|
||||
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
|
||||
})
|
||||
|
||||
require.Equal(t, "client-request-123", gotClientRequestID)
|
||||
require.Equal(t, "request-456", gotRequestID)
|
||||
}
|
||||
|
||||
func TestOpenAISubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
|
||||
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-request-123")
|
||||
parent = context.WithValue(parent, ctxkey.RequestID, "openai-request-456")
|
||||
|
||||
var gotClientRequestID string
|
||||
var gotRequestID string
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.submitUsageRecordTask(parent, func(ctx context.Context) {
|
||||
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
|
||||
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
|
||||
})
|
||||
|
||||
require.Equal(t, "openai-client-request-123", gotClientRequestID)
|
||||
require.Equal(t, "openai-request-456", gotRequestID)
|
||||
}
|
||||
@ -311,7 +311,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
if result != nil {
|
||||
upstreamModel = result.UpstreamModel
|
||||
}
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitMandatoryUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
|
||||
@ -29,7 +29,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
@ -44,7 +44,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
@ -57,7 +57,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
h.submitUsageRecordTask(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
|
||||
@ -66,12 +66,12 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *t
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
@ -82,7 +82,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
@ -97,7 +97,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
@ -110,7 +110,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
h.submitUsageRecordTask(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
|
||||
@ -119,12 +119,12 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
@ -152,7 +152,7 @@ func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallb
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
||||
h.submitMandatoryUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
@ -182,7 +182,7 @@ func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandator
|
||||
pool.Submit(func(ctx context.Context) {})
|
||||
|
||||
var called atomic.Bool
|
||||
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||
h.submitOpenAIUsageRecordTask(context.Background(), &service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
close(release)
|
||||
|
||||
@ -11,6 +11,8 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const clientRequestIDHeader = "X-Client-Request-ID"
|
||||
|
||||
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||
//
|
||||
// This is used by the Ops monitoring module for end-to-end request correlation.
|
||||
@ -21,12 +23,14 @@ func ClientRequestID() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
|
||||
if v, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(v) != "" {
|
||||
c.Header(clientRequestIDHeader, strings.TrimSpace(v))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
c.Header(clientRequestIDHeader, id)
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
||||
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
|
||||
50
backend/internal/server/middleware/client_request_id_test.go
Normal file
50
backend/internal/server/middleware/client_request_id_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientRequestIDGeneratesAndExposesID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ClientRequestID())
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
c.String(http.StatusOK, value)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.NotEmpty(t, w.Body.String())
|
||||
require.Equal(t, w.Body.String(), w.Header().Get(clientRequestIDHeader))
|
||||
}
|
||||
|
||||
func TestClientRequestIDPreservesExistingContextID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ClientRequestID())
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
c.String(http.StatusOK, value)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "existing-client-request-id"))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, "existing-client-request-id", w.Body.String())
|
||||
require.Equal(t, "existing-client-request-id", w.Header().Get(clientRequestIDHeader))
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user