Merge pull request #2170 from deqiying/fix/openai-ws-passthrough-reasoning-effort
fix(openai): 修复 WS passthrough 用量记录缺失 reasoning effort 和 User-Agent
Fix/OpenAI ws passthrough reasoning effort
This commit is contained in:
commit
ff50b8b6ea
@ -1233,6 +1233,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
hooks := &service.OpenAIWSIngressHooks{
|
hooks := &service.OpenAIWSIngressHooks{
|
||||||
|
InitialRequestModel: reqModel,
|
||||||
BeforeTurn: func(turn int) error {
|
BeforeTurn: func(turn int) error {
|
||||||
if turn == 1 {
|
if turn == 1 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@ -651,6 +652,46 @@ func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailu
|
|||||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||||
|
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||||
|
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"HIGH"}}`,
|
||||||
|
userAgent: testStringPtr("codex_cli_rs/0.125.0 test"),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotNil(t, got.log.UserAgent)
|
||||||
|
require.Equal(t, "codex_cli_rs/0.125.0 test", *got.log.UserAgent)
|
||||||
|
require.NotNil(t, got.log.ReasoningEffort)
|
||||||
|
require.Equal(t, "high", *got.log.ReasoningEffort)
|
||||||
|
require.True(t, got.log.OpenAIWSMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_PassthroughUsageLogInfersReasoningFromInitialRequestModel(t *testing.T) {
|
||||||
|
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||||
|
firstPayload: `{"type":"response.create","model":"gpt-5.4-xhigh","stream":false}`,
|
||||||
|
userAgent: testStringPtr("codex_cli_rs/0.125.0 mapped"),
|
||||||
|
channelMapping: map[string]string{
|
||||||
|
"gpt-5.4-xhigh": "gpt-5.4",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, "gpt-5.4", gjson.GetBytes(got.upstreamFirstPayload, "model").String(),
|
||||||
|
"上游首帧应使用渠道映射后的模型")
|
||||||
|
require.NotNil(t, got.log.ReasoningEffort)
|
||||||
|
require.Equal(t, "xhigh", *got.log.ReasoningEffort,
|
||||||
|
"usage log reasoning effort 必须使用渠道映射前首帧模型后缀推导")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_PassthroughUsageLogLeavesUserAgentNilWhenMissing(t *testing.T) {
|
||||||
|
got := runOpenAIResponsesWebSocketUsageLogCase(t, openAIResponsesWSUsageLogCase{
|
||||||
|
firstPayload: `{"type":"response.create","model":"gpt-5.4","stream":false,"reasoning":{"effort":"medium"}}`,
|
||||||
|
userAgent: testStringPtr(""),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Nil(t, got.log.UserAgent, "空入站 User-Agent 不应由上游握手 UA 或默认 UA 兜底")
|
||||||
|
require.NotNil(t, got.log.ReasoningEffort)
|
||||||
|
require.Equal(t, "medium", *got.log.ReasoningEffort)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@ -796,3 +837,278 @@ func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject
|
|||||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||||
return httptest.NewServer(router)
|
return httptest.NewServer(router)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIResponsesWSUsageLogCase struct {
|
||||||
|
firstPayload string
|
||||||
|
userAgent *string
|
||||||
|
channelMapping map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIResponsesWSUsageLogResult struct {
|
||||||
|
log *service.UsageLog
|
||||||
|
upstreamFirstPayload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIWSUsageHandlerAccountRepoStub struct {
|
||||||
|
service.AccountRepository
|
||||||
|
account service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
if s.account.Platform != platform {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []service.Account{s.account}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||||
|
return s.ListSchedulableByPlatform(ctx, platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
if s.account.ID != id {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
account := s.account
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
created chan *service.UsageLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerUsageLogRepoStub) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||||
|
if s.created != nil {
|
||||||
|
s.created <- log
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIWSUsageHandlerChannelRepoStub struct {
|
||||||
|
service.ChannelRepository
|
||||||
|
channels []service.Channel
|
||||||
|
groupPlatforms map[int64]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerChannelRepoStub) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||||
|
return s.channels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||||
|
out := make(map[int64]string, len(groupIDs))
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if platform := strings.TrimSpace(s.groupPlatforms[groupID]); platform != "" {
|
||||||
|
out[groupID] = platform
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstreamPayloadCh := make(chan []byte, 1)
|
||||||
|
upstreamErrCh := make(chan error, 1)
|
||||||
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
upstreamErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, payload, readErr := conn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
if readErr != nil {
|
||||||
|
upstreamErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
upstreamErrCh <- errors.New("unexpected upstream websocket message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
upstreamPayloadCh <- payload
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
writeErr := conn.Write(writeCtx, coderws.MessageText, []byte(
|
||||||
|
`{"type":"response.completed","response":{"id":"resp_usage_e2e","model":"gpt-5.4","usage":{"input_tokens":2,"output_tokens":1}}}`,
|
||||||
|
))
|
||||||
|
cancelWrite()
|
||||||
|
if writeErr != nil {
|
||||||
|
upstreamErrCh <- writeErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
upstreamErrCh <- nil
|
||||||
|
}))
|
||||||
|
defer upstreamServer.Close()
|
||||||
|
|
||||||
|
groupID := int64(4201)
|
||||||
|
account := service.Account{
|
||||||
|
ID: 9901,
|
||||||
|
Name: "openai-ws-passthrough-usage-e2e",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"base_url": upstreamServer.URL,
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.RunMode = config.RunModeSimple
|
||||||
|
cfg.Default.RateMultiplier = 1
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
accountRepo := &openAIWSUsageHandlerAccountRepoStub{account: account}
|
||||||
|
usageRepo := &openAIWSUsageHandlerUsageLogRepoStub{created: make(chan *service.UsageLog, 1)}
|
||||||
|
|
||||||
|
var channelSvc *service.ChannelService
|
||||||
|
if len(tc.channelMapping) > 0 {
|
||||||
|
channelSvc = service.NewChannelService(&openAIWSUsageHandlerChannelRepoStub{
|
||||||
|
channels: []service.Channel{{
|
||||||
|
ID: 7701,
|
||||||
|
Name: "openai-ws-e2e-channel",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
GroupIDs: []int64{groupID},
|
||||||
|
ModelMapping: map[string]map[string]string{service.PlatformOpenAI: tc.channelMapping},
|
||||||
|
}},
|
||||||
|
groupPlatforms: map[int64]string{groupID: service.PlatformOpenAI},
|
||||||
|
}, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||||
|
gatewaySvc := service.NewOpenAIGatewayService(
|
||||||
|
accountRepo,
|
||||||
|
usageRepo,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
service.NewBillingService(cfg, nil),
|
||||||
|
nil,
|
||||||
|
billingCacheSvc,
|
||||||
|
nil,
|
||||||
|
&service.DeferredService{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
channelSvc,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
cache := &concurrencyCacheMock{
|
||||||
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := &OpenAIGatewayHandler{
|
||||||
|
gatewayService: gatewaySvc,
|
||||||
|
billingCacheService: billingCacheSvc,
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 1801,
|
||||||
|
GroupID: &groupID,
|
||||||
|
User: &service.User{ID: 1701, Status: service.StatusActive},
|
||||||
|
}
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||||
|
handlerServer := httptest.NewServer(router)
|
||||||
|
defer handlerServer.Close()
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
if tc.userAgent != nil {
|
||||||
|
headers.Set("User-Agent", *tc.userAgent)
|
||||||
|
}
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(
|
||||||
|
dialCtx,
|
||||||
|
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||||
|
&coderws.DialOptions{HTTPHeader: headers, CompressionMode: coderws.CompressionContextTakeover},
|
||||||
|
)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(tc.firstPayload))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, event, err := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
var usageLog *service.UsageLog
|
||||||
|
select {
|
||||||
|
case usageLog = <-usageRepo.created:
|
||||||
|
require.NotNil(t, usageLog)
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("等待 WebSocket usage log 写入超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreamFirstPayload []byte
|
||||||
|
select {
|
||||||
|
case upstreamFirstPayload = <-upstreamPayloadCh:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("等待上游 WebSocket 首帧超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case upstreamErr := <-upstreamErrCh:
|
||||||
|
require.NoError(t, upstreamErr)
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("等待上游 WebSocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
return openAIResponsesWSUsageLogResult{
|
||||||
|
log: usageLog,
|
||||||
|
upstreamFirstPayload: upstreamFirstPayload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testStringPtr returns a pointer to a fresh string holding v, for filling
// optional *string fields in test fixtures.
func testStringPtr(v string) *string {
	ptr := new(string)
	*ptr = v
	return ptr
}
|
||||||
|
|||||||
@ -972,6 +972,62 @@ func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing
|
|||||||
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
|
"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPassthroughUsageMeta_TracksReasoningEffortAcrossTurns(t *testing.T) {
|
||||||
|
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||||
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||||
|
|
||||||
|
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","reasoning":{"effort":"medium"},"service_tier":"priority"}`)
|
||||||
|
meta := newOpenAIWSPassthroughUsageMeta("", firstFrame)
|
||||||
|
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
|
||||||
|
firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, capturedSessionModel, firstFrame)
|
||||||
|
require.NoError(t, firstErr)
|
||||||
|
require.Nil(t, firstBlocked)
|
||||||
|
meta.initFromFirstFrame(firstOut)
|
||||||
|
require.NotNil(t, meta.reasoningEffort.Load())
|
||||||
|
require.Equal(t, "medium", *meta.reasoningEffort.Load())
|
||||||
|
|
||||||
|
process := func(payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||||||
|
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||||
|
capturedSessionModel = updated
|
||||||
|
}
|
||||||
|
meta.updateSessionRequestModel(payload)
|
||||||
|
requestModelForThisFrame := meta.requestModelForFrame(payload)
|
||||||
|
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||||||
|
if model == "" {
|
||||||
|
model = capturedSessionModel
|
||||||
|
}
|
||||||
|
out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
|
||||||
|
if policyErr == nil && blocked == nil &&
|
||||||
|
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||||
|
meta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||||
|
}
|
||||||
|
return out, blocked, policyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
_, blockedSession, errSession := process([]byte(`{"type":"session.update","session":{"model":"gpt-5-high"}}`))
|
||||||
|
require.NoError(t, errSession)
|
||||||
|
require.Nil(t, blockedSession)
|
||||||
|
require.NotNil(t, meta.reasoningEffort.Load())
|
||||||
|
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "session.update 只刷新后续 fallback model,不覆盖当前 turn metadata")
|
||||||
|
|
||||||
|
_, blockedCancel, errCancel := process([]byte(`{"type":"response.cancel","reasoning_effort":"x-high"}`))
|
||||||
|
require.NoError(t, errCancel)
|
||||||
|
require.Nil(t, blockedCancel)
|
||||||
|
require.NotNil(t, meta.reasoningEffort.Load())
|
||||||
|
require.Equal(t, "medium", *meta.reasoningEffort.Load(), "非 response.create 帧不能污染当前 turn metadata")
|
||||||
|
|
||||||
|
_, blockedFlat, errFlat := process([]byte(`{"type":"response.create","reasoning_effort":"x-high"}`))
|
||||||
|
require.NoError(t, errFlat)
|
||||||
|
require.Nil(t, blockedFlat)
|
||||||
|
require.NotNil(t, meta.reasoningEffort.Load())
|
||||||
|
require.Equal(t, "xhigh", *meta.reasoningEffort.Load(), "flat reasoning_effort 必须进入 passthrough usage metadata")
|
||||||
|
|
||||||
|
_, blockedClear, errClear := process([]byte(`{"type":"response.create","model":"gpt-4o"}`))
|
||||||
|
require.NoError(t, errClear)
|
||||||
|
require.Nil(t, blockedClear)
|
||||||
|
require.Nil(t, meta.reasoningEffort.Load(), "新的 response.create 无 effort 且无可推导后缀时必须清空旧值")
|
||||||
|
}
|
||||||
|
|
||||||
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
|
// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
|
||||||
// "block keeps previous" semantic: when policy returns block on a
|
// "block keeps previous" semantic: when policy returns block on a
|
||||||
// response.create frame, that frame is never sent upstream, so billing tier
|
// response.create frame, that frame is never sent upstream, so billing tier
|
||||||
|
|||||||
@ -219,8 +219,11 @@ func (e *OpenAIWSClientCloseError) Reason() string {
|
|||||||
|
|
||||||
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
|
||||||
type OpenAIWSIngressHooks struct {
|
type OpenAIWSIngressHooks struct {
|
||||||
BeforeTurn func(turn int) error
|
// InitialRequestModel 是首帧渠道映射前的请求模型,只用于 usage metadata
|
||||||
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
// 的 reasoning effort 后缀推导,禁止用于上游请求或计费模型。
|
||||||
|
InitialRequestModel string
|
||||||
|
BeforeTurn func(turn int) error
|
||||||
|
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeOpenAIWSLogValue(value string) string {
|
func normalizeOpenAIWSLogValue(value string) string {
|
||||||
|
|||||||
@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`))
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast","reasoning":{"effort":"HIGH"}}`))
|
||||||
cancelWrite()
|
cancelWrite()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -431,6 +431,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
|||||||
require.Equal(t, 3, result.Usage.OutputTokens)
|
require.Equal(t, 3, result.Usage.OutputTokens)
|
||||||
require.NotNil(t, result.ServiceTier)
|
require.NotNil(t, result.ServiceTier)
|
||||||
require.Equal(t, "priority", *result.ServiceTier)
|
require.Equal(t, "priority", *result.ServiceTier)
|
||||||
|
require.NotNil(t, result.ReasoningEffort)
|
||||||
|
require.Equal(t, "high", *result.ReasoningEffort)
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatal("未收到 passthrough turn 结果回调")
|
t.Fatal("未收到 passthrough turn 结果回调")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -124,6 +124,73 @@ func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []
|
|||||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIWSPassthroughUsageMeta struct {
|
||||||
|
serviceTier atomic.Pointer[string]
|
||||||
|
reasoningEffort atomic.Pointer[string]
|
||||||
|
|
||||||
|
// 仅在 client->upstream filter goroutine 中读写;Load 侧通过上方原子指针同步。
|
||||||
|
sessionRequestModel string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
|
||||||
|
meta := &openAIWSPassthroughUsageMeta{
|
||||||
|
sessionRequestModel: strings.TrimSpace(initialRequestModel),
|
||||||
|
}
|
||||||
|
if meta.sessionRequestModel == "" {
|
||||||
|
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||||
|
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
|
||||||
|
m.sessionRequestModel = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
|
||||||
|
if m == nil {
|
||||||
|
return openAIWSPassthroughRequestModelForFrame(payload)
|
||||||
|
}
|
||||||
|
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
return m.sessionRequestModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||||||
|
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
|
||||||
|
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
|
||||||
|
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||||
|
}
|
||||||
|
|
||||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||||
|
|
||||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||||
@ -204,6 +271,11 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
// silently passed through, defeating the policy on every frame after
|
// silently passed through, defeating the policy on every frame after
|
||||||
// the first.
|
// the first.
|
||||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||||
|
initialRequestModel := ""
|
||||||
|
if hooks != nil {
|
||||||
|
initialRequestModel = hooks.InitialRequestModel
|
||||||
|
}
|
||||||
|
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
|
||||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||||
@ -226,7 +298,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
}
|
}
|
||||||
firstClientMessage = updatedFirst
|
firstClientMessage = updatedFirst
|
||||||
|
|
||||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
// 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
|
||||||
|
// usage 上报:filter
|
||||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||||
@ -237,11 +310,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||||
// goroutine)之间同步当前 turn 的 service_tier。
|
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||||
// 可直接 Store/Load 而无需额外封装。
|
|
||||||
var requestServiceTierPtr atomic.Pointer[string]
|
|
||||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
|
||||||
|
|
||||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -327,6 +397,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||||
capturedSessionModel = updated
|
capturedSessionModel = updated
|
||||||
}
|
}
|
||||||
|
usageMeta.updateSessionRequestModel(payload)
|
||||||
|
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
|
||||||
// Per-frame model first; if the client omits "model" on a
|
// Per-frame model first; if the client omits "model" on a
|
||||||
// follow-up frame (legal in Realtime), fall back to the
|
// follow-up frame (legal in Realtime), fall back to the
|
||||||
// session-level model captured from the first frame so the
|
// session-level model captured from the first frame so the
|
||||||
@ -337,14 +409,14 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
model = capturedSessionModel
|
model = capturedSessionModel
|
||||||
}
|
}
|
||||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
// 多轮 passthrough usage:仅在成功(non-block / non-err)
|
||||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
// 的 response.create 帧上更新 usageMeta,使用
|
||||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||||
// - 非 response.create 帧(response.cancel /
|
// - 非 response.create 帧(response.cancel /
|
||||||
// conversation.item.create / session.update 等)不携带
|
// conversation.item.create / session.update 等)不携带
|
||||||
// per-response service_tier,不应覆盖前一轮值。
|
// per-response metadata,不应覆盖前一轮值。
|
||||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
// - blocked != nil:该帧不会发送上游,usage metadata 应保持
|
||||||
// 上一轮值。
|
// 上一轮值。
|
||||||
// - policyErr != nil:异常路径,保持上一轮值。
|
// - policyErr != nil:异常路径,保持上一轮值。
|
||||||
// - 不带 service_tier 的 response.create 会让
|
// - 不带 service_tier 的 response.create 会让
|
||||||
@ -353,7 +425,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
// service_tier 时按 default 处理,billing 应如实反映。
|
// service_tier 时按 default 处理,billing 应如实反映。
|
||||||
if policyErr == nil && blocked == nil &&
|
if policyErr == nil && blocked == nil &&
|
||||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||||||
}
|
}
|
||||||
return out, blocked, policyErr
|
return out, blocked, policyErr
|
||||||
},
|
},
|
||||||
@ -397,7 +469,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||||
},
|
},
|
||||||
Model: turn.RequestModel,
|
Model: turn.RequestModel,
|
||||||
ServiceTier: requestServiceTierPtr.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
|
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||||
Stream: true,
|
Stream: true,
|
||||||
OpenAIWSMode: true,
|
OpenAIWSMode: true,
|
||||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
@ -445,7 +518,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||||
},
|
},
|
||||||
Model: relayResult.RequestModel,
|
Model: relayResult.RequestModel,
|
||||||
ServiceTier: requestServiceTierPtr.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
|
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||||||
Stream: true,
|
Stream: true,
|
||||||
OpenAIWSMode: true,
|
OpenAIWSMode: true,
|
||||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user