fix(openai): 修复 WS passthrough 使用记录缺失推理强度和 User-Agent
- 为 OpenAI Responses WebSocket v2 passthrough 补齐每轮 reasoning_effort 元数据 - 传递首帧渠道映射前模型,保留模型后缀推理强度推导能力 - 增加 usage log 端到端回归,覆盖入站 User-Agent、显式 effort 和渠道映射场景
This commit is contained in:
parent
48912014a1
commit
23555be380
@ -1233,6 +1233,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||
})
|
||||
|
||||
require.NotNil(t, got.log.UserAgent)
|
||||
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||
require.True(t, got.log.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||
channelMapping: map[string]string{
|
||||
"gpt-5.4-xhigh": "gpt-5.4",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||
"上游首帧应使用渠道映射后的模型")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||
userAgent: testStringPtr(""),
|
||||
})
|
||||
|
||||
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||
require.NotNil(t, got.log.ReasoningEffort)
|
||||
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogCase struct {
|
||||
firstPayload string
|
||||
userAgent *string
|
||||
channelMapping map[string]string
|
||||
}
|
||||
|
||||
type openAIResponsesWSUsageLogResult struct {
|
||||
log *service.UsageLog
|
||||
upstreamFirstPayload []byte
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
if s.account.Platform != platform {
|
||||
return nil, nil
|
||||
}
|
||||
return []service.Account{s.account}, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID != id {
|
||||
return nil, nil
|
||||
}
|
||||
account := s.account
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if s.created != nil {
|
||||
s.created <- log
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||
service.ChannelRepository
|
||||
channels []service.Channel
|
||||
groupPlatforms map[int64]string
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
return s.channels, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
out := make(map[int64]string, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||
out[groupID] = platform
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstreamPayloadCh := make(chan []byte, 1)
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
upstreamErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr != nil {
|
||||
upstreamErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||
return
|
||||
}
|
||||
upstreamPayloadCh <- payload
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||
))
|
||||
cancelWrite()
|
||||
if writeErr != nil {
|
||||
upstreamErrCh <- writeErr
|
||||
return
|
||||
}
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
upstreamErrCh <- nil
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
groupID := int64(4201)
|
||||
account := service.Account{
|
||||
ID: 9901,
|
||||
Name: "openai-ws-passthrough-usage-e2e",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": upstreamServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||
|
||||
var channelSvc *service.ChannelService
|
||||
if len(tc.channelMapping) > 0 {
|
||||
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||
channels: []service.Channel{{
|
||||
ID: 7701,
|
||||
Name: "openai-ws-e2e-channel",
|
||||
Status: service.StatusActive,
|
||||
GroupIDs: []int64{groupID},
|
||||
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||
}},
|
||||
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
nil,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1801,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
if tc.userAgent != nil {
|
||||
headers.Set("User-Agent", *tc.userAgent)
|
||||
}
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||
|
||||
var usageLog *service.UsageLog
|
||||
select {
|
||||
case usageLog = <-usageRepo.created:
|
||||
require.NotNil(t, usageLog)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||
}
|
||||
|
||||
var upstreamFirstPayload []byte
|
||||
select {
|
||||
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case upstreamErr := <-upstreamErrCh:
|
||||
require.NoError(t, upstreamErr)
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待上游 WebSocket 结束超时")
|
||||
}
|
||||
|
||||
return openAIResponsesWSUsageLogResult{
|
||||
log: usageLog,
|
||||
upstreamFirstPayload: upstreamFirstPayload,
|
||||
}
|
||||
}
|
||||
|
||||
func testStringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
|
||||
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
|
||||
}
|
||||
|
||||
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
|
||||
meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
|
||||
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
|
||||
require.NoError(t, firstErr)
|
||||
require.Nil(t, firstBlocked)
|
||||
meta.initFromFirstFrame(firstOut)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load())
|
||||
|
||||
process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
meta.updateSessionRequestModel(payload)
|
||||
requestModelForThisFrame := meta.requestModelForFrame(payload)
|
||||
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||||
if model == "" {
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
meta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
}
|
||||
|
||||
_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
|
||||
require.NoError(t, errSession)
|
||||
require.Nil(t, blockedSession)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata")
|
||||
|
||||
_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
|
||||
require.NoError(t, errCancel)
|
||||
require.Nil(t, blockedCancel)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")
|
||||
|
||||
_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
|
||||
require.NoError(t, errFlat)
|
||||
require.Nil(t, blockedFlat)
|
||||
require.NotNil(t, meta.reasoningEffort.Load())
|
||||
require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")
|
||||
|
||||
_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
|
||||
require.NoError(t, errClear)
|
||||
require.Nil(t, blockedClear)
|
||||
require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
|
||||
}
|
||||
|
||||
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
|
||||
// "block keeps previous" semantic: when policy returns block on a
|
||||
// response.create frame, that frame is never sent upstream, so billing tier
|
||||
|
||||
@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
|
||||
|
||||
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
||||
type OpenAIWSIngressHooks struct {
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
// InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
|
||||
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||
InitialRequestModel string
|
||||
BeforeTurn func(turn int) error
|
||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSLogValue(value string) string {
|
||||
|
||||
@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||
require.NotNil(t, result.ServiceTier)
|
||||
require.Equal(t, "priority", *result.ServiceTier)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "high", *result.ReasoningEffort)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("未收到 passthrough turn 结果回调")
|
||||
}
|
||||
|
||||
@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
type openAIWSPassthroughUsageMeta struct {
|
||||
serviceTier atomic.Pointer[string]
|
||||
reasoningEffort atomic.Pointer[string]
|
||||
|
||||
// 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。
|
||||
sessionRequestModel string
|
||||
}
|
||||
|
||||
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
|
||||
meta := &openAIWSPassthroughUsageMeta{
|
||||
sessionRequestModel: strings.TrimSpace(initialRequestModel),
|
||||
}
|
||||
if meta.sessionRequestModel == "" {
|
||||
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
|
||||
m.sessionRequestModel = model
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
|
||||
if m == nil {
|
||||
return openAIWSPassthroughRequestModelForFrame(payload)
|
||||
}
|
||||
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
|
||||
return model
|
||||
}
|
||||
return m.sessionRequestModel
|
||||
}
|
||||
|
||||
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
|
||||
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
|
||||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// silently passed through, defeating the policy on every frame after
|
||||
// the first.
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||
initialRequestModel := ""
|
||||
if hooks != nil {
|
||||
initialRequestModel = hooks.InitialRequestModel
|
||||
}
|
||||
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
|
||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||
if policyErr != nil {
|
||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||
@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
}
|
||||
firstClientMessage = updatedFirst
|
||||
|
||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
||||
// 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
|
||||
// usage 上报:filter
|
||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||
@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 service_tier。
|
||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
||||
// 可直接 Store/Load 而无需额外封装。
|
||||
var requestServiceTierPtr atomic.Pointer[string]
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
usageMeta.updateSessionRequestModel(payload)
|
||||
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
|
||||
// Per-frame model first; if the client omits "model" on a
|
||||
// follow-up frame (legal in Realtime), fall back to the
|
||||
// session-level model captured from the first frame so the
|
||||
@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
||||
// 多轮 passthrough usage:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 usageMeta,使用
|
||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||
// - 非 response.create 帧(response.cancel /
|
||||
// conversation.item.create / session.update 等)不携带
|
||||
// per-response service_tier,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
||||
// per-response metadata,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,usage metadata 应保持
|
||||
// 上一轮值。
|
||||
// - policyErr != nil:异常路径,保持上一轮值。
|
||||
// - 不带 service_tier 的 response.create 会让
|
||||
@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
// service_tier 时按 default 处理,billing 应如实反映。
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
||||
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
},
|
||||
@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
ServiceTier: usageMeta.serviceTier.Load(),
|
||||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user