chore: merge upstream Wei-Shaw/sub2api v0.1.133

This commit is contained in:
win 2026-05-29 17:48:27 +08:00
commit a420179abb
98 changed files with 7080 additions and 2190 deletions

View File

@ -1 +1 @@
0.1.132
0.1.133

View File

@ -199,7 +199,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
opsService := service.ProvideOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink, settingService)
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
if err != nil {
return nil, err

View File

@ -22,18 +22,16 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformWindsurf = "windsurf"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分)
AccountTypeWindsurfSession = "windsurf-session" // Windsurf Session 类型账号(邮箱密码登录获取的 session token + api_key
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分)
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI
)
// Redeem type constants
@ -74,7 +72,8 @@ const (
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
var DefaultAntigravityModelMapping = map[string]string{
// Claude 白名单
"claude-opus-4-7": "claude-opus-4-6", // 官方模型
"claude-opus-4-8": "claude-opus-4-8", // 官方模型
"claude-opus-4-7": "claude-opus-4-7", // 官方模型
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
@ -124,6 +123,7 @@ var DefaultAntigravityModelMapping = map[string]string{
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
"claude-opus-4-8": "us.anthropic.claude-opus-4-8-v1",
"claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1",
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",

View File

@ -24,3 +24,27 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
}
}
}
func TestDefaultAntigravityModelMapping_ContainsOpus48(t *testing.T) {
t.Parallel()
got, ok := DefaultAntigravityModelMapping["claude-opus-4-8"]
if !ok {
t.Fatal("expected mapping for claude-opus-4-8 to exist")
}
if got != "claude-opus-4-8" {
t.Fatalf("unexpected claude-opus-4-8 mapping: got %q", got)
}
}
func TestDefaultBedrockModelMapping_ContainsOpus48(t *testing.T) {
t.Parallel()
got, ok := DefaultBedrockModelMapping["claude-opus-4-8"]
if !ok {
t.Fatal("expected Bedrock mapping for claude-opus-4-8 to exist")
}
if got != "us.anthropic.claude-opus-4-8-v1" {
t.Fatalf("unexpected Bedrock claude-opus-4-8 mapping: got %q", got)
}
}

View File

@ -256,6 +256,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
OpenAIAllowClaudeCodeCodexPlugin: settings.OpenAIAllowClaudeCodeCodexPlugin,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
@ -584,6 +585,7 @@ type UpdateSettingsRequest struct {
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
OpenAIAllowClaudeCodeCodexPlugin *bool `json:"openai_allow_claude_code_codex_plugin"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
@ -1655,6 +1657,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.OpenAICodexUserAgent
}(),
OpenAIAllowClaudeCodeCodexPlugin: func() bool {
if req.OpenAIAllowClaudeCodeCodexPlugin != nil {
return *req.OpenAIAllowClaudeCodeCodexPlugin
}
return previousSettings.OpenAIAllowClaudeCodeCodexPlugin
}(),
PaymentVisibleMethodAlipaySource: func() string {
if req.PaymentVisibleMethodAlipaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
@ -2031,6 +2039,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
OpenAIAllowClaudeCodeCodexPlugin: updatedSettings.OpenAIAllowClaudeCodeCodexPlugin,
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
@ -2500,6 +2509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
changed = append(changed, "openai_codex_user_agent")
}
if before.OpenAIAllowClaudeCodeCodexPlugin != after.OpenAIAllowClaudeCodeCodexPlugin {
changed = append(changed, "openai_allow_claude_code_codex_plugin")
}
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
changed = append(changed, "payment_visible_method_alipay_source")
}

View File

@ -2,6 +2,7 @@ package admin
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
@ -17,12 +18,18 @@ import (
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
updateSvc systemUpdateService
lockSvc *service.SystemOperationLockService
}
type systemUpdateService interface {
CheckUpdate(ctx context.Context, force bool) (*service.UpdateInfo, error)
PerformUpdate(ctx context.Context) error
Rollback() error
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
func NewSystemHandler(updateSvc systemUpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{
updateSvc: updateSvc,
lockSvc: lockSvc,
@ -67,6 +74,21 @@ func (h *SystemHandler) PerformUpdate(c *gin.Context) {
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
if errors.Is(err, service.ErrNoUpdateAvailable) {
info, checkErr := h.updateSvc.CheckUpdate(ctx, false)
if checkErr != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, checkErr
}
succeeded = true
return gin.H{
"message": "Already up to date",
"already_up_to_date": true,
"current_version": info.CurrentVersion,
"latest_version": info.LatestVersion,
"operation_id": lock.OperationID(),
}, nil
}
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}

View File

@ -0,0 +1,144 @@
//go:build unit
package admin
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type systemHandlerUpdateServiceStub struct {
performErr error
updateInfo *service.UpdateInfo
checkErr error
checkForces []bool
performCall int
}
func (s *systemHandlerUpdateServiceStub) CheckUpdate(_ context.Context, force bool) (*service.UpdateInfo, error) {
s.checkForces = append(s.checkForces, force)
return s.updateInfo, s.checkErr
}
func (s *systemHandlerUpdateServiceStub) PerformUpdate(context.Context) error {
s.performCall++
return s.performErr
}
func (s *systemHandlerUpdateServiceStub) Rollback() error {
return nil
}
type systemUpdateResponseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
Message string `json:"message"`
AlreadyUpToDate bool `json:"already_up_to_date"`
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
OperationID string `json:"operation_id"`
} `json:"data"`
}
type systemUpdateErrorEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
}
func newSystemHandlerTestRouter(t *testing.T, updateSvc *systemHandlerUpdateServiceStub, repo *memoryIdempotencyRepoStub) *gin.Engine {
t.Helper()
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(nil)
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
lockSvc := service.NewSystemOperationLockService(repo, service.IdempotencyConfig{
ProcessingTimeout: time.Second,
SystemOperationTTL: time.Minute,
})
handler := NewSystemHandler(updateSvc, lockSvc)
router := gin.New()
router.POST("/api/v1/admin/system/update", handler.PerformUpdate)
return router
}
func requireSystemLockStatus(t *testing.T, repo *memoryIdempotencyRepoStub, wantStatus string) {
t.Helper()
repo.mu.Lock()
defer repo.mu.Unlock()
for _, record := range repo.data {
if record.Status == wantStatus {
return
}
}
t.Fatalf("system lock status %q not found in records: %#v", wantStatus, repo.data)
}
func TestSystemHandlerPerformUpdateAlreadyUpToDateReturnsOK(t *testing.T) {
updateSvc := &systemHandlerUpdateServiceStub{
performErr: service.ErrNoUpdateAvailable,
updateInfo: &service.UpdateInfo{
CurrentVersion: "0.1.132",
LatestVersion: "0.1.132",
HasUpdate: false,
},
}
repo := newMemoryIdempotencyRepoStub()
router := newSystemHandlerTestRouter(t, updateSvc, repo)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/system/update", nil)
req.Header.Set("Idempotency-Key", "already-up-to-date")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, updateSvc.performCall)
require.Equal(t, []bool{false}, updateSvc.checkForces)
requireSystemLockStatus(t, repo, service.IdempotencyStatusSucceeded)
var body systemUpdateResponseEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body))
require.Equal(t, 0, body.Code)
require.Equal(t, "success", body.Message)
require.Equal(t, "Already up to date", body.Data.Message)
require.True(t, body.Data.AlreadyUpToDate)
require.Equal(t, "0.1.132", body.Data.CurrentVersion)
require.Equal(t, "0.1.132", body.Data.LatestVersion)
require.NotEmpty(t, body.Data.OperationID)
}
func TestSystemHandlerPerformUpdateFailureStillReturnsInternalError(t *testing.T) {
updateSvc := &systemHandlerUpdateServiceStub{
performErr: errors.New("download failed"),
}
repo := newMemoryIdempotencyRepoStub()
router := newSystemHandlerTestRouter(t, updateSvc, repo)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/system/update", nil)
req.Header.Set("Idempotency-Key", "real-failure")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
require.Equal(t, 1, updateSvc.performCall)
require.Empty(t, updateSvc.checkForces)
requireSystemLockStatus(t, repo, service.IdempotencyStatusFailedRetryable)
var body systemUpdateErrorEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body))
require.Equal(t, http.StatusInternalServerError, body.Code)
require.Equal(t, "internal error", body.Message)
}

View File

@ -0,0 +1,27 @@
package handler
import (
"context"
"errors"
"fmt"
"net/http"
)
const statusClientClosedRequest = 499
func concurrencyErrorResponse(err error, slotType string) (int, string, string) {
var concurrencyErr *ConcurrencyError
if errors.As(err, &concurrencyErr) {
if concurrencyErr.SlotType != "" {
slotType = concurrencyErr.SlotType
}
return http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType)
}
if errors.Is(err, context.Canceled) {
return statusClientClosedRequest, "api_error", "context canceled"
}
return http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable, please retry later"
}

View File

@ -0,0 +1,63 @@
package handler
import (
"context"
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestConcurrencyErrorResponse(t *testing.T) {
tests := []struct {
name string
err error
slotType string
wantStatus int
wantType string
wantMessage string
}{
{
name: "true concurrency timeout remains rate limit",
err: &ConcurrencyError{SlotType: "account", IsTimeout: true},
slotType: "user",
wantStatus: http.StatusTooManyRequests,
wantType: "rate_limit_error",
wantMessage: "Concurrency limit exceeded for account, please retry later",
},
{
name: "client cancellation is not classified as concurrency limit",
err: context.Canceled,
slotType: "user",
wantStatus: statusClientClosedRequest,
wantType: "api_error",
wantMessage: "context canceled",
},
{
name: "deadline exceeded is service unavailable",
err: context.DeadlineExceeded,
slotType: "user",
wantStatus: http.StatusServiceUnavailable,
wantType: "api_error",
wantMessage: "Service temporarily unavailable, please retry later",
},
{
name: "redis acquire error is service unavailable",
err: errors.New("redis unavailable"),
slotType: "user",
wantStatus: http.StatusServiceUnavailable,
wantType: "api_error",
wantMessage: "Service temporarily unavailable, please retry later",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, errType, message := concurrencyErrorResponse(tt.err, tt.slotType)
require.Equal(t, tt.wantStatus, status)
require.Equal(t, tt.wantType, errType)
require.Equal(t, tt.wantMessage, message)
})
}
}

View File

@ -185,6 +185,7 @@ type SystemSettings struct {
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
OpenAIAllowClaudeCodeCodexPlugin bool `json:"openai_allow_claude_code_codex_plugin"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`

View File

@ -535,7 +535,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
@ -965,7 +965,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
@ -1531,10 +1531,10 @@ func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, su
return min
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
// handleConcurrencyError handles concurrency-related acquire errors.
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
status, errType, message := concurrencyErrorResponse(err, slotType)
h.handleStreamingAwareError(c, status, errType, message, streamStarted)
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
@ -2138,10 +2138,11 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
)
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
func (h *GatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return

View File

@ -292,7 +292,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -267,7 +267,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -336,6 +336,9 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
for {
select {
case <-ctx.Done():
if parentErr := c.Request.Context().Err(); parentErr != nil {
return nil, parentErr
}
return nil, &ConcurrencyError{
SlotType: slotType,
IsTimeout: true,

View File

@ -280,6 +280,25 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
})
}
func TestWaitForSlotWithPingTimeout_ParentContextCanceled(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
reqCtx, cancel := context.WithCancel(c.Request.Context())
c.Request = c.Request.WithContext(reqCtx)
cancel()
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
require.Nil(t, release)
require.ErrorIs(t, err, context.Canceled)
var cErr *ConcurrencyError
require.False(t, errors.As(err, &cErr))
}
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
errCache := &helperConcurrencyCacheStubWithError{
err: errors.New("redis unavailable"),

View File

@ -528,7 +528,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
QuotaPlatform: quotaPlatform,

View File

@ -127,7 +127,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
for {
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
c.Request.Context(),
apiKey.GroupID,
"",
@ -135,6 +135,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
service.OpenAIEndpointCapabilityChatCompletions,
false,
)
if err != nil {
@ -273,7 +274,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -107,7 +107,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
routingStart := time.Now()
for {
selection, _, err := h.gatewayService.SelectAccountWithScheduler(
selection, _, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
c.Request.Context(),
apiKey.GroupID,
"",
@ -115,6 +115,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportHTTPSSE,
service.OpenAIEndpointCapabilityEmbeddings,
false,
)
if err != nil {
@ -140,13 +141,6 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
return
}
account := selection.Account
if account.Type != service.AccountTypeAPIKey {
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
failedAccountIDs[account.ID] = struct{}{}
continue
}
setOpsSelectedAccount(c, account.ID, account.Platform)
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
@ -220,7 +214,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -12,6 +12,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -46,6 +47,31 @@ func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedM
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
func usageRecordContext(parent context.Context, base context.Context) context.Context {
if base == nil {
base = context.Background()
}
if parent == nil {
return base
}
if clientRequestID, _ := parent.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
base = context.WithValue(base, ctxkey.ClientRequestID, strings.TrimSpace(clientRequestID))
}
if requestID, _ := parent.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
base = context.WithValue(base, ctxkey.RequestID, strings.TrimSpace(requestID))
}
return base
}
func wrapUsageRecordTaskContext(parent context.Context, task service.UsageRecordTask) service.UsageRecordTask {
if task == nil {
return nil
}
return func(ctx context.Context) {
task(usageRecordContext(parent, ctx))
}
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@ -266,7 +292,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
c.Request.Context(),
apiKey.GroupID,
previousResponseID,
@ -274,6 +300,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
service.OpenAIEndpointCapabilityChatCompletions,
requireCompact,
)
if err != nil {
@ -437,7 +464,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -675,7 +702,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel = effectiveMappedModel
}
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
@ -683,6 +710,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
service.OpenAIEndpointCapabilityChatCompletions,
false,
)
if err != nil {
@ -821,7 +849,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -1273,7 +1301,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
for {
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
ctx,
apiKey.GroupID,
previousResponseID,
@ -1281,6 +1309,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
service.OpenAIEndpointCapabilityChatCompletions,
false,
)
if err != nil {
@ -1424,7 +1453,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
@ -1609,10 +1638,11 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
@ -1631,18 +1661,19 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
task(ctx)
}
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(parent context.Context, result *service.OpenAIForwardResult, task service.UsageRecordTask) {
if result != nil && result.ImageCount > 0 {
h.submitMandatoryUsageRecordTask(task)
h.submitMandatoryUsageRecordTask(parent, task)
return
}
h.submitUsageRecordTask(task)
h.submitUsageRecordTask(parent, task)
}
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
if task == nil {
return
}
task = wrapUsageRecordTaskContext(parent, task)
if h.usageRecordWorkerPool != nil {
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
return
@ -1685,10 +1716,10 @@ func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, stream
return nil, false
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
// handleConcurrencyError handles concurrency-related acquire errors.
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
status, errType, message := concurrencyErrorResponse(err, slotType)
h.handleStreamingAwareError(c, status, errType, message, streamStarted)
}
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {

View File

@ -867,8 +867,11 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
}
logs := repo.logSnapshot()
require.Len(t, logs, 1)
var logs []service.ContentModerationLog
require.Eventually(t, func() bool {
logs = repo.logSnapshot()
return len(logs) == 1
}, time.Second, 10*time.Millisecond)
require.True(t, logs[0].Flagged)
require.Equal(t, service.ContentModerationActionBlock, logs[0].Action)
require.Equal(t, "bad prompt", logs[0].InputExcerpt)

View File

@ -0,0 +1,41 @@
package handler
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-request-123")
parent = context.WithValue(parent, ctxkey.RequestID, "request-456")
var gotClientRequestID string
var gotRequestID string
h := &GatewayHandler{}
h.submitUsageRecordTask(parent, func(ctx context.Context) {
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
})
require.Equal(t, "client-request-123", gotClientRequestID)
require.Equal(t, "request-456", gotRequestID)
}
func TestOpenAISubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-request-123")
parent = context.WithValue(parent, ctxkey.RequestID, "openai-request-456")
var gotClientRequestID string
var gotRequestID string
h := &OpenAIGatewayHandler{}
h.submitUsageRecordTask(parent, func(ctx context.Context) {
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
})
require.Equal(t, "openai-client-request-123", gotClientRequestID)
require.Equal(t, "openai-request-456", gotRequestID)
}

View File

@ -311,7 +311,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if result != nil {
upstreamModel = result.UpstreamModel
}
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
h.submitMandatoryUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,

View File

@ -29,7 +29,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
close(done)
})
@ -44,7 +44,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
h := &GatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
@ -57,7 +57,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
h.submitUsageRecordTask(context.Background(), nil)
})
}
@ -66,12 +66,12 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *t
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
@ -82,7 +82,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
close(done)
})
@ -97,7 +97,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
h := &OpenAIGatewayHandler{}
var called atomic.Bool
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
if _, ok := ctx.Deadline(); !ok {
t.Fatal("expected deadline in fallback context")
}
@ -110,7 +110,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
h.submitUsageRecordTask(context.Background(), nil)
})
}
@ -119,12 +119,12 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
@ -152,7 +152,7 @@ func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallb
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
h.submitMandatoryUsageRecordTask(context.Background(), func(ctx context.Context) {
called.Store(true)
})
close(release)
@ -182,7 +182,7 @@ func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandator
pool.Submit(func(ctx context.Context) {})
var called atomic.Bool
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
h.submitOpenAIUsageRecordTask(context.Background(), &service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
called.Store(true)
})
close(release)

View File

@ -155,6 +155,7 @@ var claudeModels = []modelDef{
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"},
{ID: "claude-opus-4-8", DisplayName: "Claude Opus 4.8", CreatedAt: "2026-05-29T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}

View File

@ -12,6 +12,7 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
}
requiredIDs := []string{
"claude-opus-4-8",
"claude-opus-4-6-thinking",
"gemini-2.5-flash-image",
"gemini-2.5-flash-image-preview",

View File

@ -204,6 +204,8 @@ type modelInfo struct {
// 只有在此映射表中的模型才会注入身份提示词
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
var modelInfoMap = map[string]modelInfo{
"claude-opus-4-8": {DisplayName: "Claude Opus 4.8", CanonicalID: "claude-opus-4-8"},
"claude-opus-4-7": {DisplayName: "Claude Opus 4.7", CanonicalID: "claude-opus-4-7"},
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
@ -587,7 +589,8 @@ func maxOutputTokensLimit(model string) int {
func isAntigravityOpusHighTierModel(model string) bool {
lower := strings.ToLower(model)
return strings.HasPrefix(lower, "claude-opus-4-6") ||
strings.HasPrefix(lower, "claude-opus-4-7")
strings.HasPrefix(lower, "claude-opus-4-7") ||
strings.HasPrefix(lower, "claude-opus-4-8")
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {

View File

@ -1597,3 +1597,139 @@ func TestAnthropicToResponses_TemperatureStrippedForAllGpt5Variants(t *testing.T
})
}
}
// ---------------------------------------------------------------------------
// AnthropicToResponsesResponse: Anthropic input_tokens excludes cached tokens
// while OpenAI Responses input_tokens is the total including cached tokens.
// ---------------------------------------------------------------------------
func TestAnthropicToResponsesResponse_CacheTokensUseOpenAIInputSemantics(t *testing.T) {
resp := &AnthropicResponse{
ID: "msg_cache",
Model: "claude-sonnet-4-5-20250929",
Content: []AnthropicContentBlock{
{Type: "text", Text: "ok"},
},
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 3318,
OutputTokens: 123,
CacheReadInputTokens: 50688,
CacheCreationInputTokens: 200,
},
}
out := AnthropicToResponsesResponse(resp)
require.NotNil(t, out.Usage)
// 3318 (uncached) + 50688 (read) + 200 (creation) = 54206
assert.Equal(t, 54206, out.Usage.InputTokens)
assert.Equal(t, 123, out.Usage.OutputTokens)
assert.Equal(t, 54329, out.Usage.TotalTokens)
require.NotNil(t, out.Usage.InputTokensDetails)
assert.Equal(t, 50688, out.Usage.InputTokensDetails.CachedTokens)
}
func TestAnthropicToResponsesResponse_NoCacheTokens(t *testing.T) {
resp := &AnthropicResponse{
ID: "msg_nocache",
Model: "claude-sonnet-4-5-20250929",
Content: []AnthropicContentBlock{
{Type: "text", Text: "ok"},
},
StopReason: "end_turn",
Usage: AnthropicUsage{
InputTokens: 100,
OutputTokens: 50,
},
}
out := AnthropicToResponsesResponse(resp)
require.NotNil(t, out.Usage)
assert.Equal(t, 100, out.Usage.InputTokens)
assert.Equal(t, 50, out.Usage.OutputTokens)
assert.Equal(t, 150, out.Usage.TotalTokens)
assert.Nil(t, out.Usage.InputTokensDetails)
}
func TestAnthropicEventToResponses_CacheTokensRoundTripFromMessageStart(t *testing.T) {
state := NewAnthropicEventToResponsesState()
// message_start carries cache fields on the initial Usage object.
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
Type: "message_start",
Message: &AnthropicResponse{
ID: "msg_stream_cache",
Model: "claude-sonnet-4-5-20250929",
Usage: AnthropicUsage{
InputTokens: 12,
CacheReadInputTokens: 9,
CacheCreationInputTokens: 3,
},
},
}, state)
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
Type: "message_delta",
Usage: &AnthropicUsage{
OutputTokens: 7,
},
}, state)
events := AnthropicEventToResponsesEvents(&AnthropicStreamEvent{Type: "message_stop"}, state)
// The terminal response.completed event must include OpenAI-semantic usage.
var completed *ResponsesStreamEvent
for i := range events {
if events[i].Type == "response.completed" {
completed = &events[i]
}
}
require.NotNil(t, completed, "response.completed event must be emitted")
require.NotNil(t, completed.Response)
require.NotNil(t, completed.Response.Usage)
// 12 (uncached) + 9 (read) + 3 (creation) = 24
assert.Equal(t, 24, completed.Response.Usage.InputTokens)
assert.Equal(t, 7, completed.Response.Usage.OutputTokens)
assert.Equal(t, 31, completed.Response.Usage.TotalTokens)
require.NotNil(t, completed.Response.Usage.InputTokensDetails)
assert.Equal(t, 9, completed.Response.Usage.InputTokensDetails.CachedTokens)
}
func TestAnthropicEventToResponses_CacheTokensFromMessageDelta(t *testing.T) {
state := NewAnthropicEventToResponsesState()
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
Type: "message_start",
Message: &AnthropicResponse{
ID: "msg_delta_cache",
Model: "claude-sonnet-4-5-20250929",
Usage: AnthropicUsage{InputTokens: 20},
},
}, state)
// Some upstreams only emit cache fields on the final message_delta.
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
Type: "message_delta",
Usage: &AnthropicUsage{
OutputTokens: 8,
CacheReadInputTokens: 11,
CacheCreationInputTokens: 4,
},
}, state)
events := AnthropicEventToResponsesEvents(&AnthropicStreamEvent{Type: "message_stop"}, state)
var completed *ResponsesStreamEvent
for i := range events {
if events[i].Type == "response.completed" {
completed = &events[i]
}
}
require.NotNil(t, completed)
require.NotNil(t, completed.Response.Usage)
// 20 (uncached) + 11 (read) + 4 (creation) = 35
assert.Equal(t, 35, completed.Response.Usage.InputTokens)
assert.Equal(t, 8, completed.Response.Usage.OutputTokens)
require.NotNil(t, completed.Response.Usage.InputTokensDetails)
assert.Equal(t, 11, completed.Response.Usage.InputTokensDetails.CachedTokens)
}

View File

@ -95,10 +95,16 @@ func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
}
// Usage
// Anthropic's input_tokens excludes cache_read/cache_creation, while OpenAI
// Responses' input_tokens is the total including cached tokens. Add them back
// when converting so downstream consumers see OpenAI semantics.
totalInputTokens := resp.Usage.InputTokens +
resp.Usage.CacheReadInputTokens +
resp.Usage.CacheCreationInputTokens
out.Usage = &ResponsesUsage{
InputTokens: resp.Usage.InputTokens,
InputTokens: totalInputTokens,
OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
TotalTokens: totalInputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.CacheReadInputTokens > 0 {
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
@ -150,10 +156,13 @@ type AnthropicEventToResponsesState struct {
CurrentCallID string
CurrentName string
// Usage from message_delta
InputTokens int
OutputTokens int
CacheReadInputTokens int
// Usage from message_start / message_delta. InputTokens here follows
// Anthropic semantics (excludes cached tokens); they are added back when
// emitting the OpenAI Responses usage.
InputTokens int
OutputTokens int
CacheReadInputTokens int
CacheCreationInputTokens int
}
// NewAnthropicEventToResponsesState returns an initialised stream state.
@ -225,6 +234,12 @@ func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEven
if evt.Message.Usage.InputTokens > 0 {
state.InputTokens = evt.Message.Usage.InputTokens
}
if evt.Message.Usage.CacheReadInputTokens > 0 {
state.CacheReadInputTokens = evt.Message.Usage.CacheReadInputTokens
}
if evt.Message.Usage.CacheCreationInputTokens > 0 {
state.CacheCreationInputTokens = evt.Message.Usage.CacheCreationInputTokens
}
}
if state.CreatedSent {
@ -392,9 +407,15 @@ func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEven
// Update usage
if evt.Usage != nil {
state.OutputTokens = evt.Usage.OutputTokens
if evt.Usage.InputTokens > 0 {
state.InputTokens = evt.Usage.InputTokens
}
if evt.Usage.CacheReadInputTokens > 0 {
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
}
if evt.Usage.CacheCreationInputTokens > 0 {
state.CacheCreationInputTokens = evt.Usage.CacheCreationInputTokens
}
}
return nil
@ -472,10 +493,13 @@ func makeResponsesCompletedEvent(
seq := state.SequenceNumber
state.SequenceNumber++
// Anthropic's input_tokens excludes cache_read/cache_creation; add them
// back to match OpenAI Responses semantics where input_tokens is the total.
totalInputTokens := state.InputTokens + state.CacheReadInputTokens + state.CacheCreationInputTokens
usage := &ResponsesUsage{
InputTokens: state.InputTokens,
InputTokens: totalInputTokens,
OutputTokens: state.OutputTokens,
TotalTokens: state.InputTokens + state.OutputTokens,
TotalTokens: totalInputTokens + state.OutputTokens,
}
if state.CacheReadInputTokens > 0 {
usage.InputTokensDetails = &ResponsesInputTokensDetails{

View File

@ -663,6 +663,115 @@ func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesToChatCompletions_ReasoningTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_reasoning",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "ping"}},
},
},
Usage: &ResponsesUsage{
InputTokens: 24,
OutputTokens: 33,
TotalTokens: 57,
OutputTokensDetails: &ResponsesOutputTokensDetails{
ReasoningTokens: 32,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-5.5")
require.NotNil(t, chat.Usage)
assert.Equal(t, 33, chat.Usage.CompletionTokens)
require.NotNil(t, chat.Usage.CompletionTokensDetails)
assert.Equal(t, 32, chat.Usage.CompletionTokensDetails.ReasoningTokens)
}
func TestResponsesToChatCompletions_AllTokenDetailsPassThrough(t *testing.T) {
// Covers the full OpenAI CompletionUsage detail field set so future audio
// and prediction-outputs responses propagate without further changes.
resp := &ResponsesResponse{
ID: "resp_full_details",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "x"}},
},
},
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 60,
AudioTokens: 4,
},
OutputTokensDetails: &ResponsesOutputTokensDetails{
ReasoningTokens: 30,
AudioTokens: 2,
AcceptedPredictionTokens: 10,
RejectedPredictionTokens: 3,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-5.5")
require.NotNil(t, chat.Usage)
require.NotNil(t, chat.Usage.PromptTokensDetails)
assert.Equal(t, 60, chat.Usage.PromptTokensDetails.CachedTokens)
assert.Equal(t, 4, chat.Usage.PromptTokensDetails.AudioTokens)
require.NotNil(t, chat.Usage.CompletionTokensDetails)
assert.Equal(t, 30, chat.Usage.CompletionTokensDetails.ReasoningTokens)
assert.Equal(t, 2, chat.Usage.CompletionTokensDetails.AudioTokens)
assert.Equal(t, 10, chat.Usage.CompletionTokensDetails.AcceptedPredictionTokens)
assert.Equal(t, 3, chat.Usage.CompletionTokensDetails.RejectedPredictionTokens)
raw, err := json.Marshal(chat.Usage)
require.NoError(t, err)
assert.Contains(t, string(raw), `"prompt_tokens_details"`)
assert.Contains(t, string(raw), `"completion_tokens_details"`)
assert.Contains(t, string(raw), `"reasoning_tokens":30`)
assert.Contains(t, string(raw), `"accepted_prediction_tokens":10`)
}
func TestResponsesToChatCompletions_NoReasoningTokensWhenZero(t *testing.T) {
// Non-reasoning models do not return reasoning_tokens. The mapping must
// omit completion_tokens_details entirely rather than emitting a zero-valued
// field, so non-reasoning responses stay clean.
resp := &ResponsesResponse{
ID: "resp_no_reasoning",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "hi"}},
},
},
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
OutputTokensDetails: &ResponsesOutputTokensDetails{
ReasoningTokens: 0,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.NotNil(t, chat.Usage)
assert.Nil(t, chat.Usage.CompletionTokensDetails)
raw, err := json.Marshal(chat.Usage)
require.NoError(t, err)
assert.NotContains(t, string(raw), "completion_tokens_details")
assert.NotContains(t, string(raw), "reasoning_tokens")
}
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_ws",
@ -825,6 +934,32 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesEventToChatChunks_CompletedWithReasoningTokens(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-5.5"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 24,
OutputTokens: 33,
TotalTokens: 57,
OutputTokensDetails: &ResponsesOutputTokensDetails{
ReasoningTokens: 32,
},
},
},
}, state)
require.Len(t, chunks, 2)
require.NotNil(t, chunks[1].Usage)
require.NotNil(t, chunks[1].Usage.CompletionTokensDetails)
assert.Equal(t, 32, chunks[1].Usage.CompletionTokensDetails.ReasoningTokens)
}
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"

View File

@ -81,19 +81,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
FinishReason: finishReason,
}}
if resp.Usage != nil {
usage := &ChatUsage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
}
}
out.Usage = usage
}
out.Usage = chatUsageFromResponsesUsage(resp.Usage)
return out
}
@ -341,14 +329,48 @@ func chatUsageFromResponsesUsage(u *ResponsesUsage) *ChatUsage {
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
usage.PromptTokensDetails = promptDetailsFromResponses(u.InputTokensDetails)
usage.CompletionTokensDetails = completionDetailsFromResponses(u.OutputTokensDetails)
return usage
}
// promptDetailsFromResponses maps Responses-API input_tokens_details into a
// Chat-Completions prompt_tokens_details. Returns nil when nothing would be
// emitted, so upstreams that do not break down prompt usage stay clean.
func promptDetailsFromResponses(src *ResponsesInputTokensDetails) *ChatTokenDetails {
if src == nil {
return nil
}
if src.CachedTokens == 0 && src.AudioTokens == 0 {
return nil
}
return &ChatTokenDetails{
CachedTokens: src.CachedTokens,
AudioTokens: src.AudioTokens,
}
}
// completionDetailsFromResponses maps Responses-API output_tokens_details
// into a Chat-Completions completion_tokens_details. Mirrors the OpenAI
// official CompletionUsage schema: reasoning_tokens, audio_tokens, and
// the predicted-outputs accepted/rejected counts. Returns nil when nothing
// would be emitted so non-reasoning, non-audio responses stay clean.
func completionDetailsFromResponses(src *ResponsesOutputTokensDetails) *ChatTokenDetails {
if src == nil {
return nil
}
if src.ReasoningTokens == 0 && src.AudioTokens == 0 &&
src.AcceptedPredictionTokens == 0 && src.RejectedPredictionTokens == 0 {
return nil
}
return &ChatTokenDetails{
ReasoningTokens: src.ReasoningTokens,
AudioTokens: src.AudioTokens,
AcceptedPredictionTokens: src.AcceptedPredictionTokens,
RejectedPredictionTokens: src.RejectedPredictionTokens,
}
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,

View File

@ -362,11 +362,15 @@ func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
// ResponsesInputTokensDetails breaks down input token usage.
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
AudioTokens int `json:"audio_tokens,omitempty"`
}
// ResponsesOutputTokensDetails breaks down output token usage.
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
AudioTokens int `json:"audio_tokens,omitempty"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
}
// ---------------------------------------------------------------------------
@ -517,15 +521,27 @@ type ChatChoice struct {
// ChatUsage holds token counts in Chat Completions format.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *ChatTokenDetails `json:"completion_tokens_details,omitempty"`
}
// ChatTokenDetails provides a breakdown of token usage.
// ChatTokenDetails provides a breakdown of token usage. The same type is
// reused for both prompt_tokens_details and completion_tokens_details;
// unset fields are omitted so each side only emits the fields that apply.
//
// Field set mirrors OpenAI's official CompletionUsage schema:
// - prompt_tokens_details: cached_tokens, audio_tokens
// - completion_tokens_details: reasoning_tokens, audio_tokens,
// accepted_prediction_tokens, rejected_prediction_tokens
type ChatTokenDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
CachedTokens int `json:"cached_tokens,omitempty"`
AudioTokens int `json:"audio_tokens,omitempty"`
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
}
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.

View File

@ -446,6 +446,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.7",
CreatedAt: "2026-04-17T00:00:00Z",
},
{
ID: "claude-opus-4-8",
Type: "model",
DisplayName: "Claude Opus 4.8",
CreatedAt: "2026-05-29T00:00:00Z",
},
{
ID: "claude-sonnet-4-6",
Type: "model",

View File

@ -0,0 +1,78 @@
package openai
import "strings"
// 命名预设 ID。账号侧 codex_cli_only_allowed_clients 只能引用这些预设键,
// 具体匹配规则固化在下方 registry 中,配置只能「选择启用哪些预设」、不能自定义规则,
// 以防该白名单退化为可任意放宽的后门。
const (
// AllowedClientClaudeCode 对应 Claude Code CLI 的 codex 插件。
AllowedClientClaudeCode = "claude_code"
)
// AllowedClientEntry 描述一个被额外放行的非官方 Codex 客户端签名。
// Originator 必须精确等值匹配(归一化后)。
// UAContains 为必填字段:列表为空,或列表中存在任何空白 marker均视为非法配置
// 整体安全失败return false每一项都必须出现在 User-Agent 中。
// 这确保双因子匹配不会因缺失 UA 声明而退化为仅凭可伪造的 originator 单因子放行。
type AllowedClientEntry struct {
Originator string
UAContains []string
}
// allowedClientRegistry 固化各命名预设的签名规则。
//
// Claude Code codex 插件签名来源:插件以 clientInfo.name="Claude Code" 完成 app-server
// initialize 握手codex 据此把 originator 设为 "Claude Code"User-Agent 前缀同样为
// "Claude Code/"(两者同源)。若上游 Claude Code 插件更改 clientInfo.name此处需同步更新。
var allowedClientRegistry = map[string]AllowedClientEntry{
AllowedClientClaudeCode: {
Originator: "Claude Code",
UAContains: []string{"Claude Code/"},
},
}
// IsAllowedClientMatch 判断请求头是否命中给定的额外客户端签名。
// originator 必须精确等值归一化后UAContains 中每一项都必须出现在 UA 中。
// UAContains 为必填:列表为空或含任何空白 marker 均视为非法配置,整体安全失败。
func IsAllowedClientMatch(userAgent, originator string, entry AllowedClientEntry) bool {
wantOriginator := normalizeCodexClientHeader(entry.Originator)
if wantOriginator == "" {
return false
}
if normalizeCodexClientHeader(originator) != wantOriginator {
return false
}
// 预设必须声明 UA 特征:否则将退化为仅凭可伪造的 originator 单因子匹配。
if len(entry.UAContains) == 0 {
return false
}
ua := normalizeCodexClientHeader(userAgent)
for _, marker := range entry.UAContains {
normalizedMarker := normalizeCodexClientHeader(marker)
if normalizedMarker == "" {
// 空白 marker 让该项失去校验能力,会让双因子退化为仅 originator
// 单因子;视为非法配置,安全失败。
return false
}
if !strings.Contains(ua, normalizedMarker) {
return false
}
}
return true
}
// MatchAllowedClients 判断请求头是否命中 clientIDs 引用的任一预设签名。
// 未知预设 ID 会被忽略;空列表恒不放行(默认拒绝)。
func MatchAllowedClients(userAgent, originator string, clientIDs []string) bool {
for _, id := range clientIDs {
entry, ok := allowedClientRegistry[normalizeCodexClientHeader(id)]
if !ok {
continue
}
if IsAllowedClientMatch(userAgent, originator, entry) {
return true
}
}
return false
}

View File

@ -0,0 +1,95 @@
package openai
import "testing"
// 真实的 Claude Code codex 插件请求头originator 与 UA 前缀同源于 clientInfo.name="Claude Code"。
const (
testClaudeCodeOriginator = "Claude Code"
testClaudeCodeUserAgent = "Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)"
)
func TestIsAllowedClientMatch(t *testing.T) {
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{"Claude Code/"}}
tests := []struct {
name string
ua string
originator string
want bool
}{
{name: "真实签名命中", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, want: true},
{name: "大小写不敏感", ua: "claude code/0.5.0 (macos)", originator: "claude code", want: true},
{name: "originator 两侧空白被裁剪", ua: testClaudeCodeUserAgent, originator: " Claude Code ", want: true},
{name: "originator 非精确(带后缀)不命中", ua: testClaudeCodeUserAgent, originator: "Claude Code Extra", want: false},
{name: "originator 为空不命中", ua: testClaudeCodeUserAgent, originator: "", want: false},
{name: "originator 是官方 codex 不命中", ua: testClaudeCodeUserAgent, originator: "codex_cli_rs", want: false},
{name: "UA 缺少 Claude Code/ 标记不命中", ua: "curl/8.0", originator: testClaudeCodeOriginator, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsAllowedClientMatch(tt.ua, tt.originator, entry); got != tt.want {
t.Fatalf("IsAllowedClientMatch(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
}
})
}
}
func TestIsAllowedClientMatch_EmptyOriginatorEntryNeverMatches(t *testing.T) {
// registry 条目若没有配置 Originator绝不放行避免成为宽松后门。
entry := AllowedClientEntry{Originator: "", UAContains: []string{"Claude Code/"}}
if IsAllowedClientMatch(testClaudeCodeUserAgent, "", entry) {
t.Fatal("空 Originator 的条目不应匹配任何请求")
}
}
func TestIsAllowedClientMatch_EmptyUAContainsNeverMatches(t *testing.T) {
// 预设必须声明 UA 特征,否则退化为仅凭可伪造的 originator 单因子匹配,绝不放行。
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: nil}
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
t.Fatal("未声明 UA 特征的预设不应匹配,避免退化为单因子 originator 匹配")
}
}
func TestIsAllowedClientMatch_WhitespaceUAMarkerNeverMatches(t *testing.T) {
// 全空白 marker 归一化后为空,若被跳过则退化为仅 originator 单因子;
// 任何空白 marker 视为非法预设配置,必须安全失败。
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{" "}}
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
t.Fatal("UAContains 含全空白 marker 不应匹配,避免退化为单因子 originator 匹配")
}
}
func TestIsAllowedClientMatch_MixedEmptyUAMarkerNeverMatches(t *testing.T) {
// 即便 UAContains 含一个真实 marker只要其中混入任何空白 marker 也视为非法配置;
// 防止维护者只为对齐凑数而插入空字符串。
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{"", "Claude Code/"}}
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
t.Fatal("UAContains 混入空白 marker 不应匹配")
}
}
func TestMatchAllowedClients(t *testing.T) {
tests := []struct {
name string
ua string
originator string
clientIDs []string
want bool
}{
{name: "claude_code 预设命中真实签名", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{AllowedClientClaudeCode}, want: true},
{name: "claude_code 预设 + 伪造 originator 不命中", ua: testClaudeCodeUserAgent, originator: "my_client", clientIDs: []string{AllowedClientClaudeCode}, want: false},
{name: "空列表不放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: nil, want: false},
{name: "未知预设 ID 不放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{"unknown_client"}, want: false},
{name: "ID 大小写/空白容错", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{" Claude_Code "}, want: true},
{name: "多预设任一命中即放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{"unknown_client", AllowedClientClaudeCode}, want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := MatchAllowedClients(tt.ua, tt.originator, tt.clientIDs); got != tt.want {
t.Fatalf("MatchAllowedClients(%q, %q, %v) = %v, want %v", tt.ua, tt.originator, tt.clientIDs, got, tt.want)
}
})
}
}

View File

@ -548,12 +548,17 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
"openai_ws_force_http",
"openai_responses_mode",
"openai_responses_supported",
// model_rate_limits 必须进入调度快照SetModelRateLimit 写入的模型级冷却
// 时间戳accounts.extra.model_rate_limits.<modelKey>.rate_limit_reset_at
// 是 isAccountSchedulableForModelSelection/IsSchedulableForModelWithContext
// 过滤候选账号的唯一依据。缺失会导致已限流账号被反复选中,触发 failover 切号环。
// 与 service.modelRateLimitsKey 常量保持字面量一致。
"model_rate_limits",
"codex_5h_used_percent",
"codex_7d_used_percent",
"codex_5h_reset_at",
"codex_7d_reset_at",
"codex_5h_reset_after_seconds",
"codex_7d_reset_after_seconds",
"codex_usage_updated_at",
"auto_pause_5h_threshold",
"auto_pause_7d_threshold",
"auto_pause_5h_disabled",
"auto_pause_7d_disabled",
}
filtered := make(map[string]any)
for _, key := range keys {

View File

@ -100,3 +100,36 @@ func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
require.Nil(t, got.Groups)
}
func TestBuildSchedulerMetadataAccount_KeepsQuotaAutoPauseFields(t *testing.T) {
account := service.Account{
ID: 88,
Extra: map[string]any{
"codex_5h_used_percent": 12.34,
"codex_7d_used_percent": 56.78,
"codex_5h_reset_at": "2026-05-29T10:00:00Z",
"codex_7d_reset_at": "2026-06-01T10:00:00Z",
"codex_5h_reset_after_seconds": 300,
"codex_7d_reset_after_seconds": 600,
"codex_usage_updated_at": "2026-05-29T09:00:00Z",
"auto_pause_5h_threshold": 0.95,
"auto_pause_7d_threshold": 0.96,
"auto_pause_5h_disabled": true,
"auto_pause_7d_disabled": false,
},
}
got := buildSchedulerMetadataAccount(account)
require.Equal(t, 12.34, got.Extra["codex_5h_used_percent"])
require.Equal(t, 56.78, got.Extra["codex_7d_used_percent"])
require.Equal(t, "2026-05-29T10:00:00Z", got.Extra["codex_5h_reset_at"])
require.Equal(t, "2026-06-01T10:00:00Z", got.Extra["codex_7d_reset_at"])
require.Equal(t, 300, got.Extra["codex_5h_reset_after_seconds"])
require.Equal(t, 600, got.Extra["codex_7d_reset_after_seconds"])
require.Equal(t, "2026-05-29T09:00:00Z", got.Extra["codex_usage_updated_at"])
require.Equal(t, 0.95, got.Extra["auto_pause_5h_threshold"])
require.Equal(t, 0.96, got.Extra["auto_pause_7d_threshold"])
require.Equal(t, true, got.Extra["auto_pause_5h_disabled"])
require.Equal(t, false, got.Extra["auto_pause_7d_disabled"])
}

View File

@ -843,6 +843,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_codex_user_agent": "",
"openai_allow_claude_code_codex_plugin": false,
"openai_fast_policy_settings": {
"rules": []
},
@ -1079,6 +1080,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_codex_user_agent": "",
"openai_allow_claude_code_codex_plugin": false,
"openai_fast_policy_settings": {
"rules": []
},

View File

@ -11,6 +11,8 @@ import (
"go.uber.org/zap"
)
const clientRequestIDHeader = "X-Client-Request-ID"
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
//
// This is used by the Ops monitoring module for end-to-end request correlation.
@ -21,12 +23,14 @@ func ClientRequestID() gin.HandlerFunc {
return
}
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
if v, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(v) != "" {
c.Header(clientRequestIDHeader, strings.TrimSpace(v))
c.Next()
return
}
id := uuid.New().String()
c.Header(clientRequestIDHeader, id)
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
ctx = logger.IntoContext(ctx, requestLogger)

View File

@ -0,0 +1,50 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestClientRequestIDGeneratesAndExposesID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ClientRequestID())
router.GET("/", func(c *gin.Context) {
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
c.String(http.StatusOK, value)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.NotEmpty(t, w.Body.String())
require.Equal(t, w.Body.String(), w.Header().Get(clientRequestIDHeader))
}
func TestClientRequestIDPreservesExistingContextID(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ClientRequestID())
router.GET("/", func(c *gin.Context) {
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
c.String(http.StatusOK, value)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "existing-client-request-id"))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "existing-client-request-id", w.Body.String())
require.Equal(t, "existing-client-request-id", w.Header().Get(clientRequestIDHeader))
}

View File

@ -66,6 +66,15 @@ type Account struct {
modelMappingCacheRawSig uint64
}
type OpenAIEndpointCapability string
const (
OpenAIEndpointCapabilityChatCompletions OpenAIEndpointCapability = "chat_completions"
OpenAIEndpointCapabilityEmbeddings OpenAIEndpointCapability = "embeddings"
)
const openAIEndpointCapabilitiesCredentialKey = "openai_capabilities"
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
@ -1153,6 +1162,80 @@ func (a *Account) GetOpenAISessionID() string {
return strings.TrimSpace(a.GetExtraString("openai_session_id"))
}
func (a *Account) SupportsOpenAIEndpointCapability(capability OpenAIEndpointCapability) bool {
if a == nil {
return false
}
if capability == "" {
return true
}
if !a.IsOpenAI() {
return false
}
switch capability {
case OpenAIEndpointCapabilityChatCompletions:
case OpenAIEndpointCapabilityEmbeddings:
if a.Type != AccountTypeAPIKey {
return false
}
default:
return false
}
configured, found := a.openAIEndpointCapabilitySet()
if !found {
return true
}
return configured[string(capability)]
}
func (a *Account) openAIEndpointCapabilitySet() (map[string]bool, bool) {
if a == nil || a.Credentials == nil {
return nil, false
}
raw, found := a.Credentials[openAIEndpointCapabilitiesCredentialKey]
if !found || raw == nil {
return nil, false
}
result := make(map[string]bool)
add := func(value string) {
value = strings.ToLower(strings.TrimSpace(value))
if value == "" {
return
}
result[value] = true
}
switch capabilities := raw.(type) {
case []any:
for _, item := range capabilities {
if value, ok := item.(string); ok {
add(value)
}
}
case []string:
for _, value := range capabilities {
add(value)
}
case map[string]any:
for key, value := range capabilities {
enabled, ok := value.(bool)
if ok && enabled {
add(key)
}
}
case map[string]bool:
for key, enabled := range capabilities {
if enabled {
add(key)
}
}
}
return result, true
}
func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
if !a.IsOpenAI() {
return false
@ -1473,6 +1556,38 @@ func (a *Account) IsCodexCLIOnlyEnabled() bool {
return ok && enabled
}
// GetCodexCLIOnlyAllowedClients 返回 codex_cli_only 之上额外放行的命名客户端预设 ID 列表。
// 仅 OpenAI OAuth 账号生效;缺失或类型不符时返回空。预设 ID 的具体匹配规则由
// openai 包的 registry 固化,配置只能引用预设键、不能自定义规则。
func (a *Account) GetCodexCLIOnlyAllowedClients() []string {
if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil {
return nil
}
raw, ok := a.Extra["codex_cli_only_allowed_clients"]
if !ok || raw == nil {
return nil
}
switch v := raw.(type) {
case []string:
result := make([]string, 0, len(v))
for _, s := range v {
if strings.TrimSpace(s) != "" {
result = append(result, s)
}
}
return result
case []any:
result := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
result = append(result, s)
}
}
return result
}
return nil
}
// WindowCostSchedulability 窗口费用调度状态
type WindowCostSchedulability int

View File

@ -0,0 +1,68 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_GetCodexCLIOnlyAllowedClients(t *testing.T) {
t.Run("OAuth 账号读取 []any 字符串列表", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
}
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("OAuth 账号读取 []string 列表", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only_allowed_clients": []string{"claude_code"}},
}
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("[]string 跳过空白元素", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only_allowed_clients": []string{"claude_code", "", " "}},
}
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("跳过非字符串与空白元素", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code", 123, "", " "}},
}
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("非 OAuth 账号返回空", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
}
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("Extra 为空返回空", func(t *testing.T) {
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
})
t.Run("字段缺失返回空", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
})
}

View File

@ -4209,6 +4209,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 能力维度 sanitizeAnthropic-compatible 上游透传路径也需要保证 body↔beta header
// 对称。客户端 anthropic-beta header 不含 context-management-2025-06-27 但 body 带
// context_management 时 strip与 Anthropic 直连 / Bedrock / Vertex 路径保持一致。
clientBeta := c.GetHeader("anthropic-beta")
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
body = sanitized
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
@ -4224,7 +4232,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
if v := clientBeta; v != "" {
req.Header.Set("anthropic-beta", v)
}

View File

@ -88,6 +88,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认映射透传 - claude-opus-4-8",
requestedModel: "claude-opus-4-8",
accountMapping: nil,
expected: "claude-opus-4-8",
},
{
name: "默认映射透传 - claude-opus-4-7",
requestedModel: "claude-opus-4-7",
accountMapping: nil,
expected: "claude-opus-4-7",
},
{
name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6-thinking",
@ -210,6 +222,7 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射(有明确前缀映射)
{"可映射 - claude-opus-4-8", "claude-opus-4-8", true},
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传claude 和 gemini 前缀)

View File

@ -174,6 +174,7 @@ func TestIsBedrockClaude45OrNewer(t *testing.T) {
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-opus-4-8-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
@ -511,6 +512,20 @@ func TestResolveBedrockModelID(t *testing.T) {
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("default opus 4.8 mapping uses regional Bedrock model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-8")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-opus-4-8-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
@ -714,6 +729,7 @@ func TestIsBedrockOpus47OrNewer(t *testing.T) {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-8-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
{"us.anthropic.claude-opus-4-6-v1", false},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", false},
@ -886,10 +902,12 @@ func TestIsBedrockOpus47OrNewer_EdgeCases(t *testing.T) {
modelID string
expect bool
}{
{"anthropic.claude-opus-4-8-v1", true},
{"anthropic.claude-opus-4-7-v1", true},
{"us.anthropic.claude-opus-4-7-20270101-v1:0", true},
{"", false},
// Forward() passes parsed.Model (standard names), not Bedrock IDs
{"claude-opus-4-8", true},
{"claude-opus-4-7", true},
{"claude-opus-4-6", false},
{"claude-sonnet-4-7", false},

View File

@ -432,6 +432,9 @@ const (
// 当客户端 UA 被识别为浏览器Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
// SettingKeyOpenAIAllowClaudeCodeCodexPlugin 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认 false
// 仅在账号 codex_cli_only 开启时生效;开启后无需逐账号配置 codex_cli_only_allowed_clients。
SettingKeyOpenAIAllowClaudeCodeCodexPlugin = "openai_allow_claude_code_codex_plugin"
// 余额不足提醒
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关

View File

@ -476,6 +476,66 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFie
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFiltersGenerationFields(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"tool","input_schema":{"type":"object"}}],"temperature":0.7,"top_p":0.9,"top_k":40,"stream":true,"stop_sequences":["END"],"max_tokens":1024,"thinking":{"type":"enabled","budget_tokens":5000}}`)
parsed := &ParsedRequest{
Body: body,
Model: "claude-sonnet-4-20250514",
}
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 302,
Name: "count-token-filter-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
sentBody := upstream.lastBody
require.False(t, gjson.GetBytes(sentBody, "temperature").Exists())
require.False(t, gjson.GetBytes(sentBody, "top_p").Exists())
require.False(t, gjson.GetBytes(sentBody, "top_k").Exists())
require.False(t, gjson.GetBytes(sentBody, "stream").Exists())
require.False(t, gjson.GetBytes(sentBody, "stop_sequences").Exists())
require.Equal(t, "claude-sonnet-4-20250514", gjson.GetBytes(sentBody, "model").String())
require.Equal(t, "sys", gjson.GetBytes(sentBody, "system.0.text").String())
require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String())
require.Equal(t, "tool", gjson.GetBytes(sentBody, "tools.0.name").String())
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int())
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String())
}
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
// 确保空模型名不会触发映射逻辑
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {

View File

@ -66,3 +66,67 @@ func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
require.NoError(t, err)
return body
}
// Vertex 路径回归保护:同样需要
// body↔beta header 能力维度对称。客户端 header 不带 context-management beta
// 但 body 带 context_management 字段 → Vertex builder 必须 strip 字段,与 Anthropic
// 直连 / Bedrock 路径保持一致。
func TestGatewayService_BuildAnthropicVertexServiceAccount_StripsContextManagementWhenBetaMissing(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 客户端 header 只带 interleaved-thinking不带 context-management-2025-06-27
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
account := &Account{
ID: 302, Platform: PlatformAnthropic, Type: AccountTypeServiceAccount,
Credentials: map[string]any{"project_id": "vertex-proj", "location": "us-east5"},
}
// body 带了 context_management 字段(客户端透传 / normalize 补齐 / mimicry 注入等场景都可能导致)
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[{"role":"user","content":"hi"}]}`)
svc := &GatewayService{}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"vertex-token", "service_account", "claude-haiku-4-5@20251001", false, false,
)
require.NoError(t, err)
got := readRequestBodyForTest(t, req)
require.False(t, gjson.GetBytes(got, "context_management").Exists(),
"Vertex 路径下客户端 header 缺 context-management beta 时,必须 strip body 同名字段")
// header 对称断言:覆盖未来某人在 Vertex builder 里加入与 sanitize 不一致的 header 处理。
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.False(t, anthropicBetaTokensContains(outBeta, "context-management-2025-06-27"),
"与 body 对称outgoing anthropic-beta header 也不含 context-management beta")
}
// Vertex 路径反面:客户端 header 含 context-management beta 时保留字段。
func TestGatewayService_BuildAnthropicVertexServiceAccount_PreservesContextManagementWhenBetaPresent(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14,context-management-2025-06-27")
account := &Account{
ID: 303, Platform: PlatformAnthropic, Type: AccountTypeServiceAccount,
Credentials: map[string]any{"project_id": "vertex-proj", "location": "us-east5"},
}
body := []byte(`{"model":"claude-sonnet-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"vertex-token", "service_account", "claude-sonnet-4-6@20260218", false, false,
)
require.NoError(t, err)
got := readRequestBodyForTest(t, req)
require.True(t, gjson.GetBytes(got, "context_management").Exists(),
"Vertex + 客户端 header 包含 context-management beta 时字段必须保留")
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.True(t, anthropicBetaTokensContains(outBeta, "context-management-2025-06-27"),
"与 body 对称outgoing anthropic-beta header 同步含 context-management beta")
}

View File

@ -0,0 +1,667 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// ============================================================================
// 背景
// ============================================================================
//
// Anthropic 上游对 body.context_management 字段实施 Pydantic schema 校验:
// 当且仅当 anthropic-beta header 含 context-management-2025-06-27 时接受。
// 否则报:
// "context_management: Extra inputs are not permitted"
//
// 本仓采用能力维度对称约束(与 Bedrock 路径的 sanitizeBedrockFieldsForBetaTokens
// 对称):在所有 Anthropic 直连出口,按最终 anthropic-beta header 是否含上述 token
// 决定 body 是否保留同名字段。
//
// 本文件覆盖:
// 1) sanitizeAnthropicBodyForBetaTokens 纯函数
// 2) anthropicBetaTokensContains 解析辅助函数
// 3) computeFinalAnthropicBeta / computeFinalCountTokensAnthropicBeta 各路径
// 4) normalizeClaudeOAuthRequestBody 的 context_management 补齐行为(不再按 model 短路)
// ============================================================================
// anthropicBetaTokensContains
// ============================================================================
func TestAnthropicBetaTokensContains_EmptyInputs(t *testing.T) {
require.False(t, anthropicBetaTokensContains("", "context-management-2025-06-27"))
require.False(t, anthropicBetaTokensContains("oauth-2025-04-20", ""))
}
func TestAnthropicBetaTokensContains_SingleToken(t *testing.T) {
require.True(t, anthropicBetaTokensContains("context-management-2025-06-27", "context-management-2025-06-27"))
}
func TestAnthropicBetaTokensContains_MultiTokenComma(t *testing.T) {
header := "oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14"
require.True(t, anthropicBetaTokensContains(header, "context-management-2025-06-27"))
require.True(t, anthropicBetaTokensContains(header, "oauth-2025-04-20"))
require.False(t, anthropicBetaTokensContains(header, "fast-mode-2026-02-01"))
}
func TestAnthropicBetaTokensContains_ToleratesWhitespace(t *testing.T) {
header := "oauth-2025-04-20 , context-management-2025-06-27 , interleaved-thinking-2025-05-14"
require.True(t, anthropicBetaTokensContains(header, "context-management-2025-06-27"))
}
func TestAnthropicBetaTokensContains_SubstringNotMatched(t *testing.T) {
// 严格 token 比较,不应被子串误匹配
require.False(t, anthropicBetaTokensContains("context-management-2025-06-27-rev2", "context-management-2025-06-27"),
"必须按 token 边界匹配,不允许 prefix 子串误命中")
}
// ============================================================================
// sanitizeAnthropicBodyForBetaTokens
// ============================================================================
func TestSanitizeAnthropicBodyForBetaTokens_NoFieldNoChange(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","messages":[]}`)
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "oauth-2025-04-20")
require.False(t, changed)
require.Equal(t, string(body), string(out))
}
func TestSanitizeAnthropicBodyForBetaTokens_FieldKeptWhenBetaPresent(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-7","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
out, changed := sanitizeAnthropicBodyForBetaTokens(body,
"oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14")
require.False(t, changed)
require.True(t, gjson.GetBytes(out, "context_management").Exists())
require.Equal(t, "clear_thinking_20251015",
gjson.GetBytes(out, "context_management.edits.0.type").String())
}
func TestSanitizeAnthropicBodyForBetaTokens_FieldStrippedWhenBetaMissing(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "oauth-2025-04-20,interleaved-thinking-2025-05-14")
require.True(t, changed)
require.False(t, gjson.GetBytes(out, "context_management").Exists(),
"header 不含 context-management beta 时必须 strip 同名字段")
}
func TestSanitizeAnthropicBodyForBetaTokens_FieldStrippedWhenBetaEmpty(t *testing.T) {
body := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "")
require.True(t, changed)
require.False(t, gjson.GetBytes(out, "context_management").Exists())
}
func TestSanitizeAnthropicBodyForBetaTokens_EmptyBody(t *testing.T) {
out, changed := sanitizeAnthropicBodyForBetaTokens([]byte{}, "")
require.False(t, changed)
require.Empty(t, out)
out, changed = sanitizeAnthropicBodyForBetaTokens(nil, "")
require.False(t, changed)
require.Empty(t, out)
}
// ★ 关键回归断言:能力维度 sanitize 解决了 "真 CC + haiku" 路径的过度删除问题。
// 真实 Claude Code CLI 2.1.87+ 客户端 header 含 context-management beta
// 即使 model 是 haikusanitize 也不应剥离功能字段。
func TestSanitizeAnthropicBodyForBetaTokens_HaikuRealCCClientPreservesField(t *testing.T) {
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
// 真 Claude Code CLI 2.1.87+ 客户端 header 含 context-management beta
clientBeta := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27"
out, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta)
require.False(t, changed,
"真 CC 客户端 header 含 context-management beta 时haiku body 字段必须保留(功能不丢)")
require.True(t, gjson.GetBytes(out, "context_management").Exists())
}
// ============================================================================
// computeFinalAnthropicBeta — 关键路径
// ============================================================================
func newTestGatewayServiceForBeta(injectBetaForAPIKey bool) *GatewayService {
cfg := &config.Config{}
cfg.Gateway.InjectBetaForAPIKey = injectBetaForAPIKey
return &GatewayService{cfg: cfg}
}
func TestComputeFinalAnthropicBeta_OAuthMimic_NonHaiku_IncludesContextManagement(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", http.Header{}, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"OAuth mimic non-haiku 必须注入完整 CC mimicry beta含 context-management-2025-06-27")
require.True(t, anthropicBetaTokensContains(final, claude.BetaOAuth))
require.True(t, anthropicBetaTokensContains(final, claude.BetaClaudeCode))
}
func TestComputeFinalAnthropicBeta_OAuthMimic_Haiku_ExcludesContextManagement(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
require.True(t, ok)
require.False(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"OAuth mimic haiku 仅注入 oauth + interleaved-thinking不含 context-management")
require.True(t, anthropicBetaTokensContains(final, claude.BetaOAuth))
require.True(t, anthropicBetaTokensContains(final, claude.BetaInterleavedThinking))
}
func TestComputeFinalAnthropicBeta_OAuthMimic_IgnoresClientBeta(t *testing.T) {
// mimic 路径下原代码白名单透传被跳过client beta 应被忽略
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "custom-experimental-beta")
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.False(t, strings.Contains(final, "custom-experimental-beta"),
"mimic 路径必须忽略客户端 anthropic-beta header")
}
func TestComputeFinalAnthropicBeta_OAuthTransparent_NonHaiku_PreservesClientContextManagement(t *testing.T) {
// 真 CC 客户端透传:客户端 header 中的 context-management beta 必须保留
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,context-management-2025-06-27")
final, ok := s.computeFinalAnthropicBeta("oauth", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement))
}
func TestComputeFinalAnthropicBeta_OAuthTransparent_Haiku_RealCCPreservesContextManagement(t *testing.T) {
// haiku 透传 + 客户端带 context-management beta → 必须保留
// (能力维度核心场景:避免 model-name 误删客户端透传的功能 beta
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14")
final, ok := s.computeFinalAnthropicBeta("oauth", false, "claude-haiku-4-5", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"真 CC + haiku + 客户端带 context-management beta → 透传必须保留")
}
func TestComputeFinalAnthropicBeta_APIKey_PassesClientBetaThroughDropSet(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "oauth-2025-04-20,custom-beta")
final, ok := s.computeFinalAnthropicBeta("apikey", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, "oauth-2025-04-20"))
require.True(t, anthropicBetaTokensContains(final, "custom-beta"))
}
func TestComputeFinalAnthropicBeta_APIKey_NoClientBetaInjectOff_ShouldNotSet(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
final, ok := s.computeFinalAnthropicBeta("apikey", false, "claude-sonnet-4-6", http.Header{}, []byte(`{}`), nil)
require.False(t, ok, "API-key + 客户端未传 + InjectBetaForAPIKey 关 → 不应主动设置 anthropic-beta")
require.Equal(t, "", final)
}
// ============================================================================
// computeFinalCountTokensAnthropicBeta
// ============================================================================
func TestComputeFinalCountTokensAnthropicBeta_OAuthMimic_AlwaysIncludesContextManagement(t *testing.T) {
// count_tokens 路径下 mimic 不按 haiku 排除:始终注入完整 mimicry beta
s := newTestGatewayServiceForBeta(false)
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", true, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"count_tokens + mimic 即使 haiku 也注入 context-management beta与 messages 不同)")
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
"count_tokens 路径必须含 token-counting beta")
}
// 重构等价性回归:
// 原 main buildCountTokensRequest 在 count_tokens mimic 分支上不跳过白名单透传
// (与 messages mimic 不同incomingBeta 取自客户端透传。重构后必须从 clientHeaders
// 拿同一个值并 merge否则会丢失客户端 beta。
func TestComputeFinalCountTokensAnthropicBeta_OAuthMimic_PreservesClientBeta(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "custom-experimental-beta,context-1m-2025-08-07")
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", true, "claude-haiku-4-5", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, "custom-experimental-beta"),
"count_tokens mimic 不同于 messages mimic原代码会保留客户端透传的 beta")
require.True(t, anthropicBetaTokensContains(final, "context-1m-2025-08-07"),
"客户端透传的其他 beta token 同样需要保留")
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"同时 FullClaudeCodeMimicryBetas 不打折扣")
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
"同时补齐 token-counting beta")
}
// messages mimic 路径反向验证:原代码会跳过白名单透传,
// 客户端 beta 不会进入 mimic 计算。重构后 messages computeFinalAnthropicBeta
// mimic 分支依然不该使用 clientBeta。
func TestComputeFinalAnthropicBeta_OAuthMimic_IgnoresClientBetaExplicit(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "custom-experimental-beta")
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.False(t, anthropicBetaTokensContains(final, "custom-experimental-beta"),
"messages mimic 原代码跳过白名单透传 → 客户端 beta 不进入计算。"+
"与 count_tokens mimic 是不同的设计,不能合并为同一函数。")
}
func TestComputeFinalCountTokensAnthropicBeta_OAuthTransparent_NoClientBetaInjectsDefault(t *testing.T) {
// 真 CC 客户端透传 + 客户端未传 anthropic-beta → 用 CountTokensBetaHeader 兜底
s := newTestGatewayServiceForBeta(false)
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", false, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
require.True(t, ok)
require.Equal(t, claude.CountTokensBetaHeader, final)
// CountTokensBetaHeader 不含 context-management beta
require.False(t, anthropicBetaTokensContains(final, claude.BetaContextManagement))
}
func TestComputeFinalCountTokensAnthropicBeta_OAuthTransparent_AppendsBetaTokenCounting(t *testing.T) {
s := newTestGatewayServiceForBeta(false)
hdr := http.Header{}
hdr.Set("anthropic-beta", "oauth-2025-04-20,context-management-2025-06-27")
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
require.True(t, ok)
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
"客户端未带 token-counting beta 时必须补齐")
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
"客户端带的 context-management beta 必须保留")
}
// ============================================================================
// normalizeClaudeOAuthRequestBody — 回归context_management 补齐恢复原行为
// ============================================================================
//
// 重构后该函数不再按 model 名短路thinking=enabled/adaptive 时补齐 context_management
// 与 model 无关。strip 责任移交 sanitizeAnthropicBodyForBetaTokens
// buildUpstreamRequest 层按最终 beta header 执行)。
func TestNormalizeClaudeOAuthRequestBody_InjectsContextManagement_ThinkingEnabled(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-sonnet-4-6", claudeOAuthNormalizeOptions{})
require.True(t, gjson.GetBytes(out, "context_management").Exists())
require.Equal(t, "clear_thinking_20251015",
gjson.GetBytes(out, "context_management.edits.0.type").String())
}
func TestNormalizeClaudeOAuthRequestBody_InjectsContextManagement_ThinkingAdaptive(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-7","thinking":{"type":"adaptive"},"messages":[]}`)
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-opus-4-7", claudeOAuthNormalizeOptions{})
require.True(t, gjson.GetBytes(out, "context_management").Exists())
}
func TestNormalizeClaudeOAuthRequestBody_HaikuStillInjects_StripDeferredToSanitize(t *testing.T) {
// haiku + thinking=enablednormalize 阶段仍按 CLI mimicry 行为补齐字段;
// strip 由 buildUpstreamRequest 层的 sanitize 兜底(如果 final beta 不含 token
body := []byte(`{"model":"claude-haiku-4-5","thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-haiku-4-5", claudeOAuthNormalizeOptions{})
require.True(t, gjson.GetBytes(out, "context_management").Exists(),
"normalize 不再按 model 名短路strip 责任移交 sanitize 层")
}
func TestNormalizeClaudeOAuthRequestBody_PreservesClientContextManagement(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-7","context_management":{"edits":[{"type":"custom_strategy"}]},"thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-opus-4-7", claudeOAuthNormalizeOptions{})
require.Equal(t, "custom_strategy",
gjson.GetBytes(out, "context_management.edits.0.type").String(),
"客户端透传的 context_management 内容必须原样保留")
}
func TestNormalizeClaudeOAuthRequestBody_NoThinking_NoInject(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-6","messages":[]}`)
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-sonnet-4-6", claudeOAuthNormalizeOptions{})
require.False(t, gjson.GetBytes(out, "context_management").Exists())
}
// ============================================================================
// passthrough 集成测试buildUpstreamRequest-
// AnthropicAPIKeyPassthrough 与 buildCountTokensRequestAnthropicAPIKeyPassthrough
// 路径上 sanitize 是否生效。
// ============================================================================
// passthrough 集成测试不设 base_url避开 validateUpstreamBaseURL 对 cfg.Security 的依赖。
// targetURL 会走默认 claudeAPIURLsanitize 逻辑与 baseURL 是否存在无关。
func newAnthropicAPIKeyPassthroughAccountForBetaTest() *Account {
return &Account{
ID: 501,
Name: "anthropic-apikey-passthrough-ctxmgmt-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "upstream-key",
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
}
func readUpstreamBodyForTest(t *testing.T, req *http.Request) []byte {
t.Helper()
require.NotNil(t, req.Body)
b, err := io.ReadAll(req.Body)
require.NoError(t, err)
return b
}
func TestBuildUpstreamRequestAnthropicAPIKeyPassthrough_StripsContextManagementWhenClientHeaderMissingBeta(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 客户端仅带 oauth beta不带 context-management-2025-06-27
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20")
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
)
require.NoError(t, err)
require.False(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
"API-key passthrough + 客户端未带 context-management beta → strip body 字段")
}
func TestBuildUpstreamRequestAnthropicAPIKeyPassthrough_PreservesContextManagementWhenClientHeaderHasBeta(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,context-management-2025-06-27")
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
)
require.NoError(t, err)
require.True(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
"API-key passthrough + 客户端带 context-management beta → 字段保留(不过度删除)")
}
func TestBuildCountTokensRequestAnthropicAPIKeyPassthrough_StripsContextManagementWhenClientHeaderMissingBeta(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,token-counting-2024-11-01")
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildCountTokensRequestAnthropicAPIKeyPassthrough(
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
)
require.NoError(t, err)
require.False(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
"count_tokens passthrough + 客户端未带 context-management beta → strip")
}
// ============================================================================
// 集成测试buildUpstreamRequest
// 全路径验证上游 outgoing body 与 anthropic-beta header 严格对称。
// 这个测试能挡住未来某人忘调 sanitize / 将 sanitize 挪到 CCH 之后 等 regression。
// ============================================================================
func TestBuildUpstreamRequest_OAuthMimicHaiku_StripsContextManagementEndToEnd(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
account := &Account{ID: 401, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "oauth-tok"},
Status: StatusActive,
Schedulable: true,
}
// haiku + mimic CC → final beta = HaikuBetaHeader不含 context-management
// body 必须 strip。
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"oauth-tok", "oauth", "claude-haiku-4-5", false, true, // mimicClaudeCode=true
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
"OAuth mimic + haiku 端到端outgoing body 不应含 context_management")
require.False(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
"对称约束outgoing anthropic-beta header 也不带 context-management beta")
}
func TestBuildUpstreamRequest_OAuthMimicNonHaiku_PreservesContextManagementEndToEnd(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
account := &Account{ID: 402, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "oauth-tok"},
Status: StatusActive,
Schedulable: true,
}
// sonnet + mimic CC → final beta = FullClaudeCodeMimicryBetas含 context-management
// body 保留。
body := []byte(`{"model":"claude-sonnet-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"oauth-tok", "oauth", "claude-sonnet-4-6", false, true,
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
"OAuth mimic + non-haikuoutgoing body 必须保留 context_management。")
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
"对称约束outgoing anthropic-beta header 同时含 context-management beta")
}
func TestBuildUpstreamRequest_OAuthTransparentHaikuWithRealCCBeta_PreservesField(t *testing.T) {
// 端到端验证:真 CC 客户端 + haiku + 客户端 header 带 context-management beta
// → final beta 透传 → 不应该过度删除 body 字段
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta",
"claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27")
account := &Account{ID: 403, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "oauth-tok"},
Status: StatusActive, Schedulable: true,
}
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"oauth-tok", "oauth", "claude-haiku-4-5", false, false, // mimicClaudeCode=false真 CC
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
"真 CC 透传路径:客户端 header 中的 context-management beta 必须保留")
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
"回归保护:真 CC + haiku + 客户端带 beta token 时clear_thinking_20251015 功能不能静默失效")
}
// CCH 顺序语义测试sanitize 必须在 signBillingHeaderCCH 之前,
// 否则签名的 hash 与最终发送的 body 不一致,被 Anthropic 判 third-party。
//
// 该测试不走 buildUpstreamRequest 完整路径(需要 mock SettingService 成本高),
// 而是直接验证两个顺序产生的 cch 不同,证明二者不可交换。
// 测试名本身是语义约束的文档化 marker。
func TestSanitizeMustBeBeforeCCHSigning_HashConsistency(t *testing.T) {
// 构造 body含 context_management + cch=00000 占位符
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92; cch=00000;"}],"messages":[]}`)
// 最终发送场景final beta 不含 context-management beta → sanitize 会 strip
finalBeta := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
extractCCH := func(t *testing.T, b []byte) string {
t.Helper()
m := regexp.MustCompile(`\bcch=([0-9a-fA-F]{5})\b`).FindSubmatch(b)
require.NotNil(t, m, "body 里找不到 cch=<5hex> %s", string(b))
return string(m[1])
}
// === 正确顺序sanitize → signBillingHeaderCCH ===
// 1. strip context_management
sanitizedFirst, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBeta)
require.True(t, changed)
require.False(t, gjson.GetBytes(sanitizedFirst, "context_management").Exists())
// 2. 基于“strip 后的 body”算 hash
correctFinal := signBillingHeaderCCH(sanitizedFirst)
correctCCH := extractCCH(t, correctFinal)
require.NotEqual(t, "00000", correctCCH, "placeholder 应被替换")
// === 错误顺序signBillingHeaderCCH → sanitize未来 regression 场景)===
// 1. 先基于“含 context_management 的 body”算 hash → cch=H_with
signedFirst := signBillingHeaderCCH(body)
wrongCCH := extractCCH(t, signedFirst)
require.NotEqual(t, "00000", wrongCCH)
// 2. 后 strip context_management → body 变化但 cch 仍是 H_with
wrongFinal, _ := sanitizeAnthropicBodyForBetaTokens(signedFirst, finalBeta)
wrongFinalCCH := extractCCH(t, wrongFinal)
// === 关键断言 ===
// 上游验证逻辑:将 outgoing body 的 cch 还原为 00000、重算 hash、与 cch 字段比较。
// 模拟上游验证:用发送 body 算出“期望的 cch”与发送 body 里的 cch 字段比。
recomputeExpected := func(b []byte, currentCCH string) string {
t.Helper()
// 把 cch=<currentCCH> 还原为 cch=00000
re := regexp.MustCompile(`(\bcch=)` + currentCCH + `(\b)`)
restored := re.ReplaceAll(b, []byte("${1}00000${2}"))
return extractCCH(t, signBillingHeaderCCH(restored))
}
// 正确顺序:发送 body 的 cch == 重算 hash → 上游验证过
require.Equal(t, correctCCH, recomputeExpected(correctFinal, correctCCH),
"正确顺序final body 里的 cch 与重算 hash 一致 → 上游验证通过")
// 错误顺序:发送 body 的 cch 是“含 ctx 算的”,但最终 body 不含 ctx → 重算 hash 不同
require.NotEqual(t, wrongFinalCCH, recomputeExpected(wrongFinal, wrongFinalCCH),
"错误顺序final body 里的 cch 是基于含 ctx 的 body 算的,"+
"但发送 body 已 strip ctx → 上游重算 hash 与 cch 不一致 → 被判 third-party。"+
"这是 buildUpstreamRequest / buildCountTokensRequest 里 sanitize 必须在 "+
"signBillingHeaderCCH 之前的原因。")
}
// count_tokens 主路径 E2E 集成测试
func TestBuildCountTokensRequest_OAuthMimicHaiku_PreservesContextManagementEndToEnd(t *testing.T) {
// count_tokens 路径下 mimic 不按 haiku 排除,始终注入 BetaContextManagement
// → sanitize 看到最终 beta header 含 context-management beta → 字段保留。
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
account := &Account{ID: 411, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
Credentials: map[string]any{"access_token": "oauth-tok"},
Status: StatusActive, Schedulable: true,
}
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildCountTokensRequest(
context.Background(), c, account, body,
"oauth-tok", "oauth", "claude-haiku-4-5", true, // mimicClaudeCode=true
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
"count_tokens mimic 始终注入 context-management beta")
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
"对称约束final beta 含 token 时 body 字段保留")
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaTokenCounting),
"count_tokens 路径必须含 token-counting beta")
}
func TestBuildCountTokensRequest_APIKeyHaiku_StripsContextManagementEndToEnd(t *testing.T) {
// API-key + haiku + 客户端 header 不带 context-management beta → final beta 不含 → strip
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
account := &Account{ID: 412, Platform: PlatformAnthropic, Type: AccountTypeAPIKey,
Credentials: map[string]any{"api_key": "sk-ant-xxx"},
Status: StatusActive, Schedulable: true,
}
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildCountTokensRequest(
context.Background(), c, account, body,
"sk-ant-xxx", "apikey", "claude-haiku-4-5", false,
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
"count_tokens API-key + 客户端未带 beta token → body strip")
}
// count_tokens passthrough preserve 测试
func TestBuildCountTokensRequestAnthropicAPIKeyPassthrough_PreservesContextManagementWhenClientHeaderHasBeta(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,context-management-2025-06-27,token-counting-2024-11-01")
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildCountTokensRequestAnthropicAPIKeyPassthrough(
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
)
require.NoError(t, err)
require.True(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
"count_tokens passthrough + 客户端带 context-management beta → 字段保留")
}
func TestBuildUpstreamRequest_APIKeyHaikuWithContextManagement_StripsField(t *testing.T) {
// API-key + haiku + body 带 context_management + 客户端 header 未带 context-management beta
// → final beta 不含 → body 字段被 strip
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
account := &Account{ID: 404, Platform: PlatformAnthropic, Type: AccountTypeAPIKey,
Credentials: map[string]any{"api_key": "sk-ant-xxx"},
Status: StatusActive, Schedulable: true,
}
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
svc := &GatewayService{cfg: &config.Config{}}
req, err := svc.buildUpstreamRequest(
context.Background(), c, account, body,
"sk-ant-xxx", "apikey", "claude-haiku-4-5", false, false,
)
require.NoError(t, err)
outBody := readUpstreamBodyForTest(t, req)
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
"API-key + haiku + 客户端未带 beta token → body 字段必须被 strip")
}

View File

@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@ -685,6 +686,69 @@ func removeThinkingDependentContextStrategies(body []byte) []byte {
return body
}
// anthropicBetaContextManagementToken 是 context_management 字段受的 beta token。
// 与 claude.BetaContextManagement 保持一致;在本文件本地定义以避免震荡
// claude package 的该常量含义。
const anthropicBetaContextManagementToken = "context-management-2025-06-27"
// sanitizeAnthropicBodyForBetaTokens 是对 Anthropic 直连路径上 body↔beta header
// **能力维度**对称约束的统一实现,与 Bedrock 路径的
// `sanitizeBedrockFieldsForBetaTokens` 对称。
//
// 问题场景:
// - context_management 是 Claude Code CLI 2.1.87+ 默认携带的 beta 字段
// (含 clear_thinking_20251015 等清理策略)
// - 其被 Anthropic 上游接受的前提是 anthropic-beta header 含
// `context-management-2025-06-27`
// - 若两侧不一致上游 Pydantic schema 拒收:
// "context_management: Extra inputs are not permitted"
//
// 本函数按最终发送的 anthropic-beta header 决定是否保留 body 中的
// context_management 字段:缺 beta token → strip。这将限制完全建立在
// "能力维度" 上,与 model 名 / token type / mimicry 子路径无关。
//
// 调用约束:必须在 CCH 签名之前调用,否则签名 hash 与最终 body
// 不一致,上游会以 third-party 拒收。
//
// 返回 (sanitized, changed)changed 表示是否发生实际删除,供调用方决定
// 是否重用原 body 引用。
func sanitizeAnthropicBodyForBetaTokens(body []byte, anthropicBetaHeader string) ([]byte, bool) {
if len(body) == 0 {
return body, false
}
if !gjson.GetBytes(body, "context_management").Exists() {
return body, false
}
if anthropicBetaTokensContains(anthropicBetaHeader, anthropicBetaContextManagementToken) {
return body, false
}
if b, err := sjson.DeleteBytes(body, "context_management"); err == nil {
return b, true
} else {
// 不应发生gjson 刚验证过字段存在 + body 是合法 JSON。如果 sjson 仍报错,
// 调用方会拿到 (body, false),但此前 computeFinalAnthropicBeta 已按“strip 后”
// 计算了 finalBeta——两侧会不一致。记录 warning 最小限度提醒运维。
logger.LegacyPrintf("service.gateway",
"[CtxMgmtSanitize] sjson.DeleteBytes failed unexpectedly: %v (body len=%d). "+
"body and final anthropic-beta header may be out of sync.", err, len(body))
}
return body, false
}
// anthropicBetaTokensContains 检测逗号分隔的 anthropic-beta header 是否含指定 token。
// 宋体空格宽容区分大小写Anthropic beta token 始终是小写)。
func anthropicBetaTokensContains(header, token string) bool {
if header == "" || token == "" {
return false
}
for _, part := range strings.Split(header, ",") {
if strings.TrimSpace(part) == token {
return true
}
}
return false
}
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
// signature/thought_signature validation issues involving tool blocks.
//

View File

@ -26,12 +26,9 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/claudemask"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/telemetry"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/cespare/xxhash/v2"
@ -282,10 +279,6 @@ func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account,
interesting := []string{
"user-agent",
"x-app",
"x-client-app",
"x-claude-remote-session-id",
"x-claude-remote-container-id",
"x-anthropic-additional-protection",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"anthropic-beta",
@ -404,10 +397,6 @@ var allowedHeaders = map[string]bool{
"accept-encoding": true,
"x-claude-code-session-id": true,
"x-client-request-id": true,
"x-client-app": true,
"x-claude-remote-session-id": true,
"x-claude-remote-container-id": true,
"x-anthropic-additional-protection": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
@ -428,16 +417,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// GetCascadeID 获取 Windsurf Cascade 会话 ID用于 LS 多轮复用)
// Get the Windsurf Cascade ID bound to a chat session for multi-turn LS reuse.
GetCascadeID(ctx context.Context, key string) (string, error)
// SetCascadeID 写入 Cascade 会话 ID
// Persist the Cascade session ID with the given TTL.
SetCascadeID(ctx context.Context, key string, cascadeID string, ttl time.Duration) error
// DeleteCascadeID 失效 Cascade 会话 IDpanel-not-found / 错误时调用)
// Invalidate the cached Cascade session ID on panel-not-found or upstream error.
DeleteCascadeID(ctx context.Context, key string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@ -600,8 +579,7 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
@ -647,7 +625,6 @@ func NewGatewayService(
channelService *ChannelService,
resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService,
rpmTokenBucketSvc *RPMTokenBucketService,
userPlatformQuotaRepo UserPlatformQuotaRepository,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
@ -675,7 +652,6 @@ func NewGatewayService(
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
rpmTokenBucket: rpmTokenBucketSvc,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute),
@ -1179,6 +1155,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
// context_managementthinking.type 为 enabled/adaptive 时,真实 CLI 会自动
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
// 客户端显式传了就透传;否则按 CLI 行为补齐。
//
// 注:本函数不按 model 名决定是否保留 context_management。“最终 beta
// header 不含 context-management-2025-06-27 时 strip 字段”的能力维度
// 对称约束由 sanitizeAnthropicBodyForBetaTokens 在 buildUpstreamRequest /
// buildCountTokensRequest 层统一执行,与 Bedrock 路径的
// sanitizeBedrockFieldsForBetaTokens 对称。
if !gjson.GetBytes(out, "context_management").Exists() {
thinkingType := gjson.GetBytes(out, "thinking.type").String()
if thinkingType == "enabled" || thinkingType == "adaptive" {
@ -1431,9 +1413,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
if platform == PlatformAnthropic && s.enableTierFallbackChain() {
return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs)
}
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
@ -2382,14 +2361,6 @@ func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, gr
return len(accounts) == 1
}
func (s *GatewayService) IsSingleWindsurfAccountGroup(ctx context.Context, groupID *int64) bool {
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformWindsurf, true)
if err != nil {
return false
}
return len(accounts) == 1
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
@ -2717,15 +2688,6 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
return err
}
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
if s.rpmTokenBucket == nil {
return nil
}
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash
@ -3754,12 +3716,6 @@ func summarizeSelectionFailureStats(stats selectionFailureStats) string {
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformWindsurf {
if strings.TrimSpace(requestedModel) == "" {
return true
}
return windsurf.ResolveModel(requestedModel) != ""
}
if account.Platform == PlatformAntigravity {
if strings.TrimSpace(requestedModel) == "" {
return true
@ -3784,12 +3740,6 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformWindsurf {
if strings.TrimSpace(requestedModel) == "" {
return true
}
return windsurf.ResolveModel(requestedModel) != ""
}
if account.Platform == PlatformAntigravity {
if strings.TrimSpace(requestedModel) == "" {
return true
@ -4148,12 +4098,10 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
// 模型仍通过 messages 接收完整指令,保留客户端功能
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
// 规范化 env 字段Platform/Shell/OS/路径),防止真实机器信息被 Anthropic 用作跨账号关联信号。
normalizedSystemText := NormalizeSystemPromptEnv(originalSystemText)
instrMsg, err1 := json.Marshal(map[string]any{
"role": "user",
"content": []map[string]any{
{"type": "text", "text": "[System Instructions]\n" + normalizedSystemText},
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
},
})
ackMsg, err2 := json.Marshal(map[string]any{
@ -4599,44 +4547,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 解析 TLS 指纹 profile同一请求生命周期内不变避免重试循环中重复解析
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
// Bootstrap 预热:模拟真实 CLI 启动时的 GET /api/claude_cli/bootstrap 调用
// 真实 CLI 在首次 messages 请求前 fire-and-forget 调用此端点。
if tokenType == "oauth" && token != "" {
instanceSalt := ""
useSharedUpstream := proxyURL != "" || tlsProfile != nil
if s.cfg != nil {
instanceSalt = s.cfg.Gateway.InstanceSalt
useSharedUpstream = useSharedUpstream || s.cfg.Gateway.TLSFingerprint.Enabled
}
TriggerBootstrapIfNeeded(account.ID, token, &BackgroundRequestOptions{
ProxyURL: proxyURL,
HTTPUpstream: s.httpUpstream,
TLSProfile: tlsProfile,
InstanceSalt: instanceSalt,
UseSharedUpstream: useSharedUpstream,
})
// OTEL telemetry: emit pre-request events (tengu_started, tengu_api_query etc.)
go telemetry.EmitPreRequest(
fmt.Sprintf("%d", account.ID),
token,
token,
reqModel,
getHeaderRaw(c.Request.Header, "anthropic-beta"),
)
}
// 调试日志:记录即将转发的账号信息
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL)
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
if account.IsContextCompressionEnabled() {
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
body = compressMessagesInBody(body, maxTok)
}
// 重试循环
var resp *http.Response
retryStart := time.Now()
@ -5032,18 +4948,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应
// OTEL telemetry: emit post-request events (fire-and-forget)
if tokenType == "oauth" && token != "" {
go telemetry.EmitPostRequest(
fmt.Sprintf("%d", account.ID),
token,
token,
reqModel,
getHeaderRaw(c.Request.Header, "anthropic-beta"),
resp.StatusCode,
)
}
// 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil {
parsed.OnUpstreamAccepted()
@ -5350,6 +5254,17 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
targetURL = validatedURL + "/v1/messages?beta=true"
}
// 能力维度 body sanitize透传路径上 anthropic-beta header 原样透传客户端值,
// 依此决定是否保留 body 中的 context_management。避免“客户端 body 带字段但
// header 忘记带 beta token”的客户端 bug 在透传场景下让上游 400。
clientBeta := ""
if c != nil && c.Request != nil {
clientBeta = getHeaderRaw(c.Request.Header, "anthropic-beta")
}
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
body = sanitized
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@ -6175,9 +6090,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号应用统一指纹和metadata重写受设置开关控制
var fingerprint *Fingerprint
enableFP, enableMPT, _ := true, false, false
enableFP, enableMPT, enableCCH := true, false, false
if s.settingService != nil {
enableFP, enableMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
@ -6208,9 +6123,33 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if fingerprint != nil {
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)。
// 无占位符时函数为 no-op故无需 enableCCH gate — 占位符存在即意味着必须签名。
body = signBillingHeaderCCH(body)
// === 计算最终 anthropic-beta header先于 body sanitize 与 CCH 签名)===
//
// 顺序约束:
// 1) 算 finalBeta纯函数不依赖 req.Headermimicry 路径会忽略客户端 beta
// 与原“OAuth + mimicClaudeCode 跳过白名单透传”行为对齐)
// 2) 按 finalBeta 做能力维度 body sanitize如 context-management beta 缺失 →
// strip body.context_management与 Bedrock 路径对称)
// 3) CCH 签名(必须使用 strip 后的 body否则 hash 与最终 body 不一致 →
// 被 Anthropic 判 third-party
// 4) NewRequestbody 至此最终敲定)
// 5) 透传白名单 / fingerprint / mimic header / 写入 finalBeta
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
effectiveDropSet := mergeDropSets(policyFilterSet)
finalBetaHeader, finalBetaShouldSet := s.computeFinalAnthropicBeta(
tokenType, mimicClaudeCode, modelID, clientHeaders, body, effectiveDropSet,
)
// 能力维度 body sanitize与最终 anthropic-beta header 对称
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBetaHeader); changed {
body = sanitized
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
if enableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
@ -6257,57 +6196,27 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeOAuthHeaderDefaults(req)
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
effectiveDropSet := mergeDropSets(policyFilterSet)
// 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta
if tokenType == "oauth" {
if mimicClaudeCode {
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Claude Code OAuth credentials are scoped to Claude Code.
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
// this as a legitimate Claude Code request; without it, the request is
// rejected as third-party ("out of extra usage").
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
}
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
setHeaderRaw(req.Header, "anthropic-beta", beta)
}
}
}
// OAuth + mimic Claude Code强制注入 CLI 指纹相关 header
// user-agent/x-stainless-*/x-app/Accept/x-stainless-helper-method/x-client-request-id
if tokenType == "oauth" && mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, reqStream)
}
// X-Claude-Code-Session-Id 头处理:
// Claude Code 2.1.145 SDK 内强制设置该头(`"X-Claude-Code-Session-Id":y_()`)。
// 优先取 metadata.user_id 中的 sessionIDOAuth mimic 场景缺失时兜底 UUID
// 避免上游基于该头缺失判定为第三方调用。
ensureClaudeCodeSessionID(req, body, tokenType, mimicClaudeCode)
// 写入最终 anthropic-beta header
// 注:透传分支白名单可能写入了客户端 anthropic-beta无条件 Del 一次再按 finalBeta
// 决定是否 set确保 dropSet 过滤后的结果一定覆盖客户端原始值。
deleteHeaderAllForms(req.Header, "anthropic-beta")
if finalBetaShouldSet {
setHeaderRaw(req.Header, "anthropic-beta", finalBetaHeader)
}
// x-client-request-id: 真实 CLI 每个请求生成新 UUID仅 1P
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
// === DEBUG: 打印上游转发请求headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
@ -6320,25 +6229,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
"enable_mpt": strconv.FormatBool(enableMPT),
})
// === Node.js 伪装验证与清理 ===
// 对于 OAuth 账号,验证请求是否正确伪装为 Node.js Claude Code 客户端
// 这是一个"猴子补丁",确保所有头都符合 Node.js 客户端标准
if tokenType == "oauth" && account.IsOAuth() {
// 1. 清理任何可能暴露 Go 客户端身份的头
if ua := req.Header.Get("User-Agent"); ua == "" || strings.Contains(ua, "Go-http-client") {
// User-Agent 缺失或包含 Go 指示,修复为 Node.js 格式
setHeaderRaw(req.Header, "User-Agent", claude.DefaultUserAgent())
}
// 2. 验证 Node.js 指纹完整性(用于调试日志)
if s.debugClaudeMimicEnabled() {
isValid, errors := claudemask.ValidateNodeEmulation(req)
if !isValid {
logger.LegacyPrintf("service.gateway", "⚠️ Node.js emulation validation failed: %v", errors)
}
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if c != nil && tokenType == "oauth" {
@ -6349,7 +6239,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
return req, nil
}
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
@ -6365,6 +6254,16 @@ func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
if err != nil {
return nil, err
}
// 能力维度 sanitizeVertex 路径上 anthropic-beta header 原样透传客户端值
// (下面白名单跳过 anthropic-version 但保留 anthropic-beta依此决定是否
// 保留 body 中的 context_management与 Anthropic 直连 / Bedrock 路径对称。
if c != nil && c.Request != nil {
clientBeta := getHeaderRaw(c.Request.Header, "anthropic-beta")
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(vertexBody, clientBeta); changed {
vertexBody = sanitized
}
}
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
if err != nil {
return nil, err
@ -6449,7 +6348,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.HaikuBetaHeader
}
return claude.GetOAuthBetaHeader(modelID)
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
@ -6466,7 +6365,10 @@ func requestNeedsBetaFeatures(body []byte) bool {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
return claude.GetAPIKeyBetaHeader(modelID)
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
}
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
@ -6476,7 +6378,7 @@ func applyClaudeOAuthHeaderDefaults(req *http.Request) {
if getHeaderRaw(req.Header, "Accept") == "" {
setHeaderRaw(req.Header, "Accept", "application/json")
}
for key, value := range claude.DefaultHeadersSnapshot() {
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
@ -6530,6 +6432,121 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
return strings.Join(out, ",")
}
// computeFinalAnthropicBeta 计算发往上游的最终 anthropic-beta header 值。
//
// 设计动机:将原本在 buildUpstreamRequest 内联在一起、依赖 req.Header 的
// anthropic-beta 计算逻辑抽成纯函数。这样调用方可以在 NewRequest 之前
// 就提前拿到最终 beta header进而能按它对 body 做能力维度 sanitize 后再做
// CCH 签名——一举修复了以下之前由顺序依赖导致的能力维度 sanitize
// 无法部署的问题(签名与最终 body 不一致可以被判 third-party
//
// 返回 (value, shouldSet)
// - shouldSet=false 意为“不主动设置 anthropic-beta header”与原代码“
// API-key 账号 + 客户端未传 anthropic-beta + InjectBetaForAPIKey 未开启或
// requestNeedsBetaFeatures=false”的行为对齐。
// - shouldSet=true 时 value 可能为空字符串(例如客户端透传的 beta 被 dropSet
// 全部过滤掉),这与原代码中 setHeaderRaw 的结果一致。
//
// clientHeaders 是客户端原始 HTTP header通常为 c.Request.Headernil 时按“客户端
// 未传”处理。body 是已经 metadata 重写 / billing version sync 之后但未 sanitize 上游
// 不兼容字段之前的版本。
func (s *GatewayService) computeFinalAnthropicBeta(
tokenType string,
mimicClaudeCode bool,
modelID string,
clientHeaders http.Header,
body []byte,
effectiveDropSet map[string]struct{},
) (string, bool) {
clientBeta := ""
if clientHeaders != nil {
clientBeta = getHeaderRaw(clientHeaders, "anthropic-beta")
}
if tokenType == "oauth" {
if mimicClaudeCode {
// mimic 路径原代码跳过白名单透传incomingBeta 总是空字符串。
// 这里传空 string 以严格对齐原行为。
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
return mergeAnthropicBetaDropping(requiredBetas, "", effectiveDropSet), true
}
// 真 Claude Code 客户端透传路径
return stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBeta), effectiveDropSet), true
}
// API-key accounts
if clientBeta != "" {
return stripBetaTokensWithSet(clientBeta, effectiveDropSet), true
}
if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
return beta, true
}
}
}
return "", false
}
// computeFinalCountTokensAnthropicBeta 是 count_tokens 路径上 anthropic-beta header 的
// 计算纯函数。语义与 computeFinalAnthropicBeta 对齐,但备份了 count_tokens 独有的
// 两条特殊规则:
//
// - OAuth mimicrequiredBetas 为 FullClaudeCodeMimicryBetas + BetaTokenCounting
// (与 messages 不同的是:不按 haiku 排除count_tokens 始终携带 token-counting beta
// - OAuth 透传 + 客户端未传 anthropic-beta补齐 CountTokensBetaHeader
// - OAuth 透传 + 客户端传了:补齐 BetaTokenCounting如果未含
//
// 返回语义同 computeFinalAnthropicBeta。
func (s *GatewayService) computeFinalCountTokensAnthropicBeta(
tokenType string,
mimicClaudeCode bool,
modelID string,
clientHeaders http.Header,
body []byte,
effectiveDropSet map[string]struct{},
) (string, bool) {
clientBeta := ""
if clientHeaders != nil {
clientBeta = getHeaderRaw(clientHeaders, "anthropic-beta")
}
if tokenType == "oauth" {
if mimicClaudeCode {
// 与原代码严格等价original buildCountTokensRequest 在 count_tokens mimic
// 分支上**不**会跳过白名单透传(与 messages mimic 路径不同),所以
// incomingBeta = req.Header[anthropic-beta] = 客户端透传过来的 client beta。
// 重构后直接从 clientHeaders 拿同一个值,保持行为一致。
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
return mergeAnthropicBetaDropping(requiredBetas, clientBeta, effectiveDropSet), true
}
if clientBeta == "" {
return claude.CountTokensBetaHeader, true
}
beta := s.getBetaHeader(modelID, clientBeta)
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
return stripBetaTokensWithSet(beta, effectiveDropSet), true
}
// API-key accounts
if clientBeta != "" {
return stripBetaTokensWithSet(clientBeta, effectiveDropSet), true
}
if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
return beta, true
}
}
}
return "", false
}
// stripBetaTokens removes the given beta tokens from a comma-separated header value.
func stripBetaTokens(header string, tokens []string) string {
if header == "" || len(tokens) == 0 {
@ -6810,7 +6827,7 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
applyClaudeOAuthHeaderDefaults(req)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
// 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App"
for key, value := range claude.DefaultHeadersSnapshot() {
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
@ -9072,9 +9089,6 @@ func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context,
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
if account.Platform == PlatformWindsurf {
return windsurf.ResolveModel(requestedModel)
}
if account.Platform == PlatformAntigravity {
return mapAntigravityModel(account, requestedModel)
}
@ -9158,7 +9172,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// Antigravity 账户不支持 count_tokens返回 404 让客户端 fallback 到本地估算。
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
if account.Platform == PlatformAntigravity || account.Platform == PlatformWindsurf {
if account.Platform == PlatformAntigravity {
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil
}
@ -9434,6 +9448,16 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
body = sanitizeCountTokensRequestBody(body)
// 同 buildUpstreamRequestAnthropicAPIKeyPassthrough能力维度 sanitize。
clientBeta := ""
if c != nil && c.Request != nil {
clientBeta = getHeaderRaw(c.Request.Header, "anthropic-beta")
}
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
body = sanitized
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
if err != nil {
@ -9501,9 +9525,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID受设置开关控制
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT, _ := true, false, false
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
if s.settingService != nil {
ctEnableFP, ctEnableMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
@ -9525,8 +9549,23 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if ctFingerprint != nil && ctEnableFP {
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
}
// 无占位符时函数为 no-op故无需 ctEnableCCH gate。
body = signBillingHeaderCCH(body)
// === 计算最终 anthropic-beta header先于 body sanitize 与 CCH 签名)===
// 顺序约束同 buildUpstreamRequest。
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
finalBetaHeader, finalBetaShouldSet := s.computeFinalCountTokensAnthropicBeta(
tokenType, mimicClaudeCode, modelID, clientHeaders, body, ctEffectiveDropSet,
)
// 能力维度 body sanitize与最终 anthropic-beta header 对称
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBetaHeader); changed {
body = sanitized
}
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
body = sanitizeCountTokensRequestBody(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
@ -9567,50 +9606,24 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeOAuthHeaderDefaults(req)
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
if mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
if clientBetaHeader == "" {
setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader)
} else {
beta := s.getBetaHeader(modelID, clientBetaHeader)
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
}
}
} else {
// API-key accounts: apply beta policy filter to strip controlled tokens
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
setHeaderRaw(req.Header, "anthropic-beta", beta)
}
}
}
// OAuth + mimic Claude Code强制注入 CLI 指纹 header
if tokenType == "oauth" && mimicClaudeCode {
applyClaudeCodeMimicHeaders(req, false)
}
// X-Claude-Code-Session-Id 头处理count_tokens 路径):
// 与 messages 路径保持同样逻辑OAuth mimic 场景缺失时兜底 UUID。
ensureClaudeCodeSessionID(req, body, tokenType, mimicClaudeCode)
// 写入最终 anthropic-beta headerDel 一次避免白名单透传值残留)
deleteHeaderAllForms(req.Header, "anthropic-beta")
if finalBetaShouldSet {
setHeaderRaw(req.Header, "anthropic-beta", finalBetaHeader)
}
// x-client-request-idcount_tokens 路径)
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
if c != nil && tokenType == "oauth" {
@ -9623,6 +9636,25 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
return req, nil
}
func sanitizeCountTokensRequestBody(body []byte) []byte {
out := body
for _, path := range []string{
"temperature",
"top_p",
"top_k",
"stream",
"stop_sequences",
"stop",
} {
if gjson.GetBytes(out, path).Exists() {
if next, ok := deleteJSONPathBytes(out, path); ok {
out = next
}
}
}
return out
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{

View File

@ -2031,6 +2031,22 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
parts := extractGeminiParts(geminiResp)
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
// Close an open tool_use block before starting text, mirroring
// the functionCall branch (which closes open text blocks) and
// the chat-completions sibling's closeOpenTool(). Otherwise a
// tool→text sequence keeps the tool_use block open while the
// text block starts, emitting overlapping Anthropic content
// blocks that violate the SSE contract.
if openToolIndex >= 0 {
writeSSE(c.Writer, "content_block_stop", map[string]any{
"type": "content_block_stop",
"index": openToolIndex,
})
openToolIndex = -1
openToolName = ""
seenToolJSON = ""
}
delta, newSeen := computeGeminiTextDelta(seenText, text)
seenText = newSeen
if delta == "" {

View File

@ -832,3 +832,108 @@ func TestParseGeminiRateLimitResetTime(t *testing.T) {
})
}
}
// TestGeminiMessagesHandleStreamingResponse_ClosesToolBlockBeforeText guards the
// tool→text ordering in the Gemini→Anthropic (messages) streaming bridge. When
// Gemini emits a functionCall part followed by a text part, the tool_use content
// block must be closed before the text block opens; otherwise the Anthropic SSE
// stream contains overlapping content blocks. The chat-completions sibling
// already enforces this via closeOpenTool().
func TestGeminiMessagesHandleStreamingResponse_ClosesToolBlockBeforeText(t *testing.T) {
gin.SetMode(gin.TestMode)
upstreamBody := `data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"SF"}}}]}}]}` + "\n\n" +
`data: {"candidates":[{"content":{"parts":[{"text":"All done."}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}` + "\n\n" +
"data: [DONE]\n\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
svc := &GeminiMessagesCompatService{}
result, err := svc.handleStreamingResponse(c, resp, time.Now(), "claude-3-5-sonnet")
require.NoError(t, err)
require.NotNil(t, result)
events := parseAnthropicContentBlockEvents(t, rec.Body.String())
// Anthropic allows at most one content block open at a time: every
// content_block_start must be matched by a content_block_stop before the
// next start. Replay the lifecycle and assert there is no overlap.
open := -1
blockTypes := map[int]string{}
textStarted := false
toolClosed := false
toolClosedBeforeText := false
for _, ev := range events {
switch ev.event {
case "content_block_start":
require.Equalf(t, -1, open,
"content block %d opened while block %d was still open (overlapping blocks)", ev.index, open)
open = ev.index
blockTypes[ev.index] = ev.blockType
if ev.blockType == "text" {
textStarted = true
if toolClosed {
toolClosedBeforeText = true
}
}
case "content_block_stop":
require.Equalf(t, open, ev.index,
"content_block_stop index %d does not match the open block %d", ev.index, open)
if blockTypes[ev.index] == "tool_use" {
toolClosed = true
}
open = -1
}
}
require.True(t, textStarted, "expected a text content block to be emitted after the tool call")
require.True(t, toolClosedBeforeText, "tool_use block must be closed before the text block starts")
require.Equal(t, -1, open, "stream ended with a content block still open")
}
type anthropicContentBlockEvent struct {
event string
index int
blockType string
}
// parseAnthropicContentBlockEvents extracts content_block_start/stop events (with
// their index and, for starts, the content block type) from an Anthropic SSE body.
func parseAnthropicContentBlockEvents(t *testing.T, raw string) []anthropicContentBlockEvent {
t.Helper()
var events []anthropicContentBlockEvent
for _, chunk := range strings.Split(raw, "\n\n") {
var eventName, dataLine string
for _, line := range strings.Split(chunk, "\n") {
switch {
case strings.HasPrefix(line, "event:"):
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
case strings.HasPrefix(line, "data:"):
dataLine = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
}
if eventName != "content_block_start" && eventName != "content_block_stop" {
continue
}
var payload struct {
Index int `json:"index"`
ContentBlock struct {
Type string `json:"type"`
} `json:"content_block"`
}
require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
events = append(events, anthropicContentBlockEvent{
event: eventName,
index: payload.Index,
blockType: payload.ContentBlock.Type,
})
}
return events
}

View File

@ -117,6 +117,20 @@ func addHeaderRaw(h http.Header, key, value string) {
h[key] = append(h[key], value)
}
// deleteHeaderAllForms removes a header in all common key forms (raw, wire casing,
// canonical) so subsequent setHeaderRaw will not coexist with a passthrough value
// written under a different casing.
func deleteHeaderAllForms(h http.Header, key string) {
if h == nil || key == "" {
return
}
h.Del(key) // canonical
delete(h, key)
if wk := resolveWireCasing(key); wk != key {
delete(h, wk)
}
}
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
// between Go canonical keys, wire casing keys, and raw keys:
// 1. exact key as provided

View File

@ -44,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
PreviousResponseID string
RequestedModel string
RequiredTransport OpenAIUpstreamTransport
RequiredCapability OpenAIEndpointCapability
RequiredImageCapability OpenAIImagesCapability
RequireCompact bool
ExcludedIDs map[int64]struct{}
@ -263,12 +264,13 @@ func (s *defaultOpenAIAccountScheduler) Select(
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
if previousResponseID != "" {
selection, err := s.service.SelectAccountByPreviousResponseID(
selection, err := s.service.selectAccountByPreviousResponseIDForCapability(
ctx,
req.GroupID,
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
req.RequiredCapability,
req.RequireCompact,
)
if err != nil {
@ -363,12 +365,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact, req.RequiredCapability)
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result != nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
@ -794,11 +795,11 @@ func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
@ -934,11 +935,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency非 EffectiveLoadFactor因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
@ -977,6 +978,13 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
return false
}
// Quota auto-pause must be evaluated during the initial filter too. Without it the
// TopK candidate pool can be filled with paused accounts and the later fresh/DB
// rechecks won't reach healthy accounts that fell outside TopK — manifesting as
// "no available accounts" even though healthy ones exist.
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
return false
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false
}
@ -985,7 +993,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
return false
}
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
return accountSupportsOpenAICapabilities(account, req.RequiredCapability, req.RequiredImageCapability)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@ -1108,7 +1116,21 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", "", requireCompact)
}
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForCapability(
ctx context.Context,
groupID *int64,
previousResponseID string,
sessionHash string,
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requiredCapability OpenAIEndpointCapability,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, requiredCapability, "", requireCompact)
}
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
@ -1119,13 +1141,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
excludedIDs map[int64]struct{},
requiredCapability OpenAIImagesCapability,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", requiredCapability, false)
if err == nil && selection != nil && selection.Account != nil {
return selection, decision, nil
}
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basicOAuth 账号)
if requiredCapability == OpenAIImagesCapabilityNative {
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", OpenAIImagesCapabilityBasic, false)
}
return selection, decision, err
}
@ -1138,9 +1160,11 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
requiredCapability OpenAIEndpointCapability,
requiredImageCapability OpenAIImagesCapability,
requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
ctx = s.withOpenAIQuotaAutoPauseContext(ctx)
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
@ -1148,14 +1172,14 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
if accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
@ -1173,14 +1197,15 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
for {
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability)
if err != nil {
return nil, decision, err
}
if selection == nil || selection.Account == nil {
return selection, decision, nil
}
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) &&
accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) {
return selection, decision, nil
}
if selection.ReleaseFunc != nil {
@ -1217,12 +1242,21 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
PreviousResponseID: previousResponseID,
RequestedModel: requestedModel,
RequiredTransport: requiredTransport,
RequiredCapability: requiredCapability,
RequiredImageCapability: requiredImageCapability,
RequireCompact: requireCompact,
ExcludedIDs: excludedIDs,
})
}
func accountSupportsOpenAICapabilities(account *Account, requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability) bool {
if account == nil {
return false
}
return account.SupportsOpenAIEndpointCapability(requiredCapability) &&
account.SupportsOpenAIImageCapability(requiredImageCapability)
}
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
if len(excludedIDs) == 0 {
return nil

View File

@ -417,6 +417,64 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10110)
accounts := []Account{
{
ID: 36031,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions"},
},
},
{
ID: 36032,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions", "embeddings"},
},
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
ctx,
&groupID,
"",
"",
"text-embedding-3-small",
nil,
OpenAIUpstreamTransportHTTPSSE,
OpenAIEndpointCapabilityEmbeddings,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36032), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
@ -482,6 +540,141 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
require.True(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10111)
accounts := []Account{
{
ID: 37011,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions"},
},
},
{
ID: 37012,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions", "embeddings"},
},
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
ctx,
&groupID,
"",
"",
"text-embedding-3-small",
nil,
OpenAIUpstreamTransportHTTPSSE,
OpenAIEndpointCapabilityEmbeddings,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37012), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.Equal(t, 1, decision.CandidateCount)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyStickyBindings(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10112)
accounts := []Account{
{
ID: 37021,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions"},
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
{
ID: 37022,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions", "embeddings"},
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
}
cfg := newSchedulerTestOpenAIWSV2Config()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_embeddings": 37021,
},
}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_embeddings_chat_only", 37021, time.Hour))
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
ctx,
&groupID,
"resp_embeddings_chat_only",
"session_hash_embeddings",
"text-embedding-3-small",
nil,
OpenAIUpstreamTransportHTTPSSE,
OpenAIEndpointCapabilityEmbeddings,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37022), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.False(t, decision.StickyPreviousHit)
require.False(t, decision.StickySessionHit)
require.Equal(t, int64(37022), cache.sessionBindings["openai:session_hash_embeddings"])
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
@ -522,6 +715,224 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AutoPauseBy5hThreshold(t *testing.T) {
ctx := context.Background()
primary := Account{
ID: 35001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 95.0,
"auto_pause_5h_threshold": 0.95,
},
}
secondary := Account{ID: 35002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35002), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AllowsBelow5hThreshold(t *testing.T) {
ctx := context.Background()
primary := Account{
ID: 35101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 80.0,
"auto_pause_5h_threshold": 0.95,
},
}
secondary := Account{ID: 35102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35101), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AutoPauseBy7dThreshold(t *testing.T) {
ctx := context.Background()
primary := Account{
ID: 35201,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_7d_used_percent": 95.0,
"auto_pause_7d_threshold": 0.95,
},
}
secondary := Account{ID: 35202, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35202), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_UnconfiguredThresholdKeepsLegacyBehavior(t *testing.T) {
ctx := context.Background()
primary := Account{ID: 35301, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, Extra: map[string]any{"codex_5h_used_percent": 99.0, "codex_7d_used_percent": 99.0}}
secondary := Account{ID: 35302, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35301), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_UsesGlobalDefaultThreshold(t *testing.T) {
ctx := withOpenAIQuotaAutoPauseSettings(context.Background(), OpsOpenAIAccountQuotaAutoPauseSettings{DefaultThreshold5h: 0.95})
primary := Account{
ID: 35401,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 95.0,
},
}
secondary := Account{ID: 35402, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35402), account.ID)
}
// Regression: a per-account explicit-disable flag exempts the account from auto-pause
// even when a global default threshold is set. Without this, "leave threshold blank"
// silently falls back to global default and admins have no way to whitelist a single
// account.
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_PerAccountDisableOverridesGlobalDefault(t *testing.T) {
ctx := withOpenAIQuotaAutoPauseSettings(context.Background(), OpsOpenAIAccountQuotaAutoPauseSettings{DefaultThreshold5h: 0.95})
// Account has high usage AND no per-account threshold (would normally fall back to
// the global default and get paused), but the explicit disable flag is set.
primary := Account{
ID: 35701,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 99.0,
"auto_pause_5h_disabled": true,
},
}
secondary := Account{ID: 35702, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35701), account.ID)
}
// Disable is per-window: disabling only 5h must still allow 7d auto-pause to fire.
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_PerWindowDisableScoped(t *testing.T) {
ctx := context.Background()
primary := Account{
ID: 35801,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 99.0,
"codex_7d_used_percent": 99.0,
"auto_pause_5h_disabled": true,
"auto_pause_7d_threshold": 0.95,
},
}
secondary := Account{ID: 35802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35802), account.ID, "7d auto-pause must still fire even though 5h is disabled")
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_StaleUsageWindowResetSkipsPause(t *testing.T) {
ctx := context.Background()
// Usage is over threshold but the window's reset time has already passed, so the
// cached percentage is stale (the real window rolled over) and the account must NOT
// stay paused — otherwise it could be skipped forever with no traffic to refresh it.
primary := Account{
ID: 35501,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 99.0,
"auto_pause_5h_threshold": 0.95,
"codex_5h_reset_at": time.Now().Add(-time.Minute).Format(time.RFC3339),
},
}
secondary := Account{ID: 35502, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35501), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_FreshUsageWindowStillPauses(t *testing.T) {
ctx := context.Background()
// Same as above but the window has not reset yet, so the account stays paused.
primary := Account{
ID: 35601,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 99.0,
"auto_pause_5h_threshold": 0.95,
"codex_5h_reset_at": time.Now().Add(time.Hour).Format(time.RFC3339),
},
}
secondary := Account{ID: 35602, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(35602), account.ID)
}
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10102)
@ -1069,6 +1480,85 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
}
// Regression: TopK initial filter must drop quota-auto-paused accounts. Otherwise
// the candidate pool is filled with paused accounts, healthy accounts fall outside
// TopK, and the scheduler returns "no available accounts" even though healthy ones
// exist.
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKExcludesQuotaPaused(t *testing.T) {
ctx := context.Background()
groupID := int64(110)
accounts := []Account{
{
ID: 37001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
Extra: map[string]any{
"codex_5h_used_percent": 96.0,
"auto_pause_5h_threshold": 0.95,
},
},
{
ID: 37002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
},
}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.LBTopK = 1 // TopK=1 makes the bug fatal: paused account would crowd out the healthy one entirely
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
37001: {AccountID: 37001, LoadRate: 5, WaitingCount: 0},
37002: {AccountID: 37002, LoadRate: 5, WaitingCount: 0},
},
acquireResults: map[int64]bool{
37002: true,
},
}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
false,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
// Only the healthy account should ever enter the candidate pool; the paused one
// must be filtered out at the initial-filter stage.
require.Equal(t, 1, decision.CandidateCount)
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
ctx := context.Background()
groupID := int64(12)

View File

@ -13,6 +13,10 @@ const (
CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched"
// CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。
CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched"
// CodexClientRestrictionReasonMatchedAllowedClient 表示请求命中账号级额外放行的命名客户端预设。
CodexClientRestrictionReasonMatchedAllowedClient = "allowed_client_matched"
// CodexClientRestrictionReasonMatchedGlobalAllowedClient 表示请求命中全局额外放行的命名客户端预设。
CodexClientRestrictionReasonMatchedGlobalAllowedClient = "global_allowed_client_matched"
// CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。
CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched"
// CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。
@ -28,7 +32,7 @@ type CodexClientRestrictionDetectionResult struct {
// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。
type CodexClientRestrictionDetector interface {
Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult
Detect(c *gin.Context, account *Account, globalAllowedClients []string) CodexClientRestrictionDetectionResult
}
// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。
@ -40,7 +44,7 @@ func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexCli
return &OpenAICodexClientRestrictionDetector{cfg: cfg}
}
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account, globalAllowedClients []string) CodexClientRestrictionDetectionResult {
if account == nil || !account.IsCodexCLIOnlyEnabled() {
return CodexClientRestrictionDetectionResult{
Enabled: false,
@ -78,6 +82,26 @@ func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *A
}
}
// 官方客户端白名单未命中时,先尝试账号级额外放行的命名客户端预设(如 Claude Code codex 插件)。
if allowed := account.GetCodexCLIOnlyAllowedClients(); len(allowed) > 0 &&
openai.MatchAllowedClients(userAgent, originator, allowed) {
return CodexClientRestrictionDetectionResult{
Enabled: true,
Matched: true,
Reason: CodexClientRestrictionReasonMatchedAllowedClient,
}
}
// 再尝试由更高作用域(全局设置)注入的额外放行客户端列表。
if len(globalAllowedClients) > 0 &&
openai.MatchAllowedClients(userAgent, originator, globalAllowedClients) {
return CodexClientRestrictionDetectionResult{
Enabled: true,
Matched: true,
Reason: CodexClientRestrictionReasonMatchedGlobalAllowedClient,
}
}
return CodexClientRestrictionDetectionResult{
Enabled: true,
Matched: false,

View File

@ -30,7 +30,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account)
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account, nil)
require.False(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
@ -44,7 +44,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account)
result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
@ -58,7 +58,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account)
result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
@ -72,7 +72,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account)
result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
@ -86,7 +86,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account)
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, result.Reason)
@ -100,7 +100,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account)
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, nil)
require.True(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
@ -116,9 +116,146 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account)
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
})
}
func TestOpenAICodexClientRestrictionDetector_Detect_AllowedClients(t *testing.T) {
gin.SetMode(gin.TestMode)
const (
claudeCodeUA = "Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)"
claudeCodeOriginator = "Claude Code"
)
t.Run("配置 claude_code 白名单且命中真实签名时放行", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
"codex_cli_only_allowed_clients": []any{"claude_code"},
},
}
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedAllowedClient, result.Reason)
})
t.Run("配置白名单但伪造 originator 仍拒绝", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
"codex_cli_only_allowed_clients": []any{"claude_code"},
},
}
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, "my_client"), account, nil)
require.True(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
})
t.Run("未配置白名单时 Claude Code 签名仍拒绝", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
require.True(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
})
t.Run("未开启 codex_cli_only 时白名单不参与,直接绕过", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
}
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
require.False(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
})
t.Run("全局列表含 claude_code + 命中签名 → 放行(global)", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)", "Claude Code"),
account,
[]string{"claude_code"},
)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedGlobalAllowedClient, result.Reason)
})
t.Run("全局列表含 claude_code + 非签名 → 403", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, []string{"claude_code"})
require.True(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
})
t.Run("全局列表为空 + 账号未配 → 403", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{"codex_cli_only": true},
}
result := detector.Detect(
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos) (Claude Code; 1.0.4)", "Claude Code"),
account,
nil,
)
require.True(t, result.Enabled)
require.False(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
})
t.Run("账号白名单优先于全局列表reason=account", func(t *testing.T) {
detector := NewOpenAICodexClientRestrictionDetector(nil)
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
"codex_cli_only_allowed_clients": []any{"claude_code"},
},
}
result := detector.Detect(
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos) (Claude Code; 1.0.4)", "Claude Code"),
account,
[]string{"claude_code"},
)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonMatchedAllowedClient, result.Reason)
})
}

View File

@ -901,7 +901,17 @@ func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMet
}
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
return s.getCodexClientRestrictionDetector().Detect(c, account)
var globalAllowedClients []string
if account != nil && account.IsCodexCLIOnlyEnabled() && s != nil && s.settingService != nil {
ctx := context.Background()
if c != nil && c.Request != nil {
ctx = c.Request.Context()
}
if s.settingService.IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx) {
globalAllowedClients = []string{openai.AllowedClientClaudeCode}
}
}
return s.getCodexClientRestrictionDetector().Detect(c, account, globalAllowedClients)
}
func getAPIKeyIDFromContext(c *gin.Context) int64 {
@ -959,6 +969,7 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
}
log := logger.FromContext(ctx).With(fields...)
if result.Matched {
log.Info("OpenAI codex_cli_only 放行请求")
return
}
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
@ -1279,7 +1290,7 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0)
return s.selectAccountForModelWithExclusions(s.withOpenAIQuotaAutoPauseContext(ctx), groupID, sessionHash, requestedModel, excludedIDs, false, 0, "")
}
// noAvailableOpenAISelectionError builds the standard "no account available" error
@ -1312,19 +1323,228 @@ func openAICompactSupportTier(account *Account) int {
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
// compact-support checks used during account selection.
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool {
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) bool {
if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
return false
}
if paused, reason := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
// Debug level: this fires per-candidate on the scheduling hot path, so Info
// would amplify into log spam once several accounts cross the threshold.
slog.Debug("account_auto_paused_by_quota",
"account_id", account.ID,
"window", reason.window,
"threshold", reason.threshold,
"utilization", reason.utilization,
)
return false
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return false
}
if !account.SupportsOpenAIEndpointCapability(requiredCapability) {
return false
}
if requireCompact && openAICompactSupportTier(account) == 0 {
return false
}
return true
}
type openAIQuotaAutoPauseDecision struct {
window string
threshold float64
utilization float64
}
func shouldAutoPauseOpenAIAccountByQuota(ctx context.Context, account *Account) (bool, openAIQuotaAutoPauseDecision) {
if account == nil || !account.IsOpenAI() {
return false, openAIQuotaAutoPauseDecision{}
}
// Per-account explicit-disable flags must take precedence over the global default.
// Without these, leaving the account threshold blank means "use global default",
// so an admin has no way to exempt a single account from auto-pause once a global
// default exists. The disable flag is per-window so an account can opt out of
// only 5h or only 7d auto-pause.
disabled5h := resolveAccountExtraBool(account.Extra, "auto_pause_5h_disabled")
disabled7d := resolveAccountExtraBool(account.Extra, "auto_pause_7d_disabled")
threshold5h, threshold7d := resolveOpenAIQuotaAutoPauseThresholds(ctx, account)
now := time.Now()
if !disabled5h && threshold5h > 0 {
if utilization, ok := resolveOpenAIQuotaUtilization(account.Extra, "5h", now); ok && utilization >= threshold5h {
return true, openAIQuotaAutoPauseDecision{window: "5h", threshold: threshold5h, utilization: utilization}
}
}
if !disabled7d && threshold7d > 0 {
if utilization, ok := resolveOpenAIQuotaUtilization(account.Extra, "7d", now); ok && utilization >= threshold7d {
return true, openAIQuotaAutoPauseDecision{window: "7d", threshold: threshold7d, utilization: utilization}
}
}
return false, openAIQuotaAutoPauseDecision{}
}
// resolveAccountExtraBool reads a bool-like value from account extra, tolerating
// the few shapes JSON unmarshalling may produce (real bool, "true"/"false"
// strings, 0/1 numbers).
func resolveAccountExtraBool(extra map[string]any, key string) bool {
if len(extra) == 0 {
return false
}
value, ok := extra[key]
if !ok || value == nil {
return false
}
switch v := value.(type) {
case bool:
return v
case string:
parsed, err := strconv.ParseBool(strings.TrimSpace(v))
return err == nil && parsed
case float64:
return v != 0
case float32:
return v != 0
case int:
return v != 0
case int64:
return v != 0
case json.Number:
if i, err := v.Int64(); err == nil {
return i != 0
}
}
return false
}
func resolveOpenAIQuotaAutoPauseThresholds(ctx context.Context, account *Account) (float64, float64) {
threshold5h, _ := resolveAccountExtraNumber(account.Extra, "auto_pause_5h_threshold")
threshold7d, _ := resolveAccountExtraNumber(account.Extra, "auto_pause_7d_threshold")
threshold5h = clamp01(threshold5h)
threshold7d = clamp01(threshold7d)
if threshold5h > 0 && threshold7d > 0 {
return threshold5h, threshold7d
}
settings := openAIQuotaAutoPauseSettingsFromContext(ctx)
if threshold5h <= 0 {
threshold5h = clamp01(settings.DefaultThreshold5h)
}
if threshold7d <= 0 {
threshold7d = clamp01(settings.DefaultThreshold7d)
}
return threshold5h, threshold7d
}
func resolveAccountExtraNumber(extra map[string]any, keys ...string) (float64, bool) {
if len(extra) == 0 {
return 0, false
}
for _, key := range keys {
value, ok := extra[key]
if !ok || value == nil {
continue
}
switch v := value.(type) {
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int64:
return float64(v), true
case json.Number:
parsed, err := v.Float64()
if err == nil {
return parsed, true
}
case string:
parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
if err == nil {
return parsed, true
}
}
}
return 0, false
}
// resolveOpenAIQuotaUtilization returns the current utilization ratio (0..1) for the
// given Codex usage window. ok=false means there is no usable signal to pause on:
// either no snapshot exists, or the window has already rolled over so the cached
// percentage is stale. The stale guard matters because a paused account stops
// receiving requests, so its snapshot is never refreshed from upstream headers —
// without this check an old used_percent would keep the account paused forever even
// after the real window reset.
func resolveOpenAIQuotaUtilization(extra map[string]any, window string, now time.Time) (float64, bool) {
usedPercent := readOpenAIQuotaUsedPercent(extra, window)
if usedPercent <= 0 {
return 0, false
}
if openAIQuotaWindowReset(extra, window, now) {
return 0, false
}
return usedPercent / 100, true
}
// openAIQuotaWindowReset reports whether the Codex usage window's reset time has
// already passed relative to now. It prefers the absolute codex_<window>_reset_at
// timestamp and falls back to codex_<window>_reset_after_seconds anchored at
// codex_usage_updated_at, mirroring AccountUsageService's window-progress logic.
func openAIQuotaWindowReset(extra map[string]any, window string, now time.Time) bool {
if len(extra) == 0 {
return false
}
if resetAtRaw, ok := extra["codex_"+window+"_reset_at"]; ok {
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
return !now.Before(resetAt)
}
}
resetAfter := parseExtraInt(extra["codex_"+window+"_reset_after_seconds"])
if resetAfter <= 0 {
return false
}
base := now
if updatedRaw, ok := extra["codex_usage_updated_at"]; ok {
if updatedAt, err := parseTime(fmt.Sprint(updatedRaw)); err == nil {
base = updatedAt
}
}
resetAt := base.Add(time.Duration(resetAfter) * time.Second)
return !now.Before(resetAt)
}
func readOpenAIQuotaUsedPercent(extra map[string]any, window string) float64 {
if len(extra) == 0 {
return 0
}
if value, ok := resolveAccountExtraNumber(extra, "codex_"+window+"_used_percent"); ok {
return value
}
return 0
}
type openAIQuotaAutoPauseCtxKey struct{}
func withOpenAIQuotaAutoPauseSettings(ctx context.Context, settings OpsOpenAIAccountQuotaAutoPauseSettings) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, openAIQuotaAutoPauseCtxKey{}, settings)
}
func openAIQuotaAutoPauseSettingsFromContext(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
if ctx == nil {
return OpsOpenAIAccountQuotaAutoPauseSettings{}
}
settings, _ := ctx.Value(openAIQuotaAutoPauseCtxKey{}).(OpsOpenAIAccountQuotaAutoPauseSettings)
return settings
}
func (s *OpenAIGatewayService) withOpenAIQuotaAutoPauseContext(ctx context.Context) context.Context {
if s == nil || s.settingService == nil {
return ctx
}
return withOpenAIQuotaAutoPauseSettings(ctx, s.settingService.GetOpenAIQuotaAutoPauseSettings(ctx))
}
// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known
// compact support are tried first, followed by unknown, then explicitly unsupported.
// The relative order within each tier is preserved.
@ -1366,7 +1586,7 @@ func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedMode
return upstreamModel
}
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) {
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
@ -1376,7 +1596,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 1. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil {
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability); account != nil {
return account, nil
}
@ -1389,7 +1609,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact, requiredCapability)
if selected == nil {
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
@ -1414,7 +1634,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account {
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) *Account {
if sessionHash == "" {
return nil
}
@ -1446,14 +1666,14 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(account) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
@ -1477,7 +1697,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// Returns nil if no available account. The second return reports whether at
// least one candidate was filtered out solely because it lacks compact support
// (only meaningful when requireCompact=true).
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) {
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*Account, bool) {
var selected *Account
selectedCompactTier := -1
compactBlocked := false
@ -1492,11 +1712,11 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i
continue
}
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false)
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false, requiredCapability)
if fresh == nil {
continue
}
@ -1573,10 +1793,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false)
return s.selectAccountWithLoadAwareness(s.withOpenAIQuotaAutoPauseContext(ctx), groupID, sessionHash, requestedModel, excludedIDs, false, "")
}
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) {
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
@ -1593,7 +1813,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID)
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability)
if err != nil {
return nil, err
}
@ -1646,8 +1866,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if clearSticky {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
}
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) {
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if s.isOpenAIAccountRuntimeBlocked(account) {
@ -1691,15 +1911,12 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
if !isOpenAIAccountEligibleForRequest(ctx, acc, requestedModel, false, requiredCapability) {
continue
}
if s.isOpenAIAccountRuntimeBlocked(acc) {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
continue
}
@ -1779,11 +1996,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false, requiredCapability)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
if fresh == nil {
continue
}
@ -1813,11 +2030,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
ordered = prioritizeOpenAICompactAccounts(ordered)
}
for _, acc := range ordered {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
if fresh == nil {
continue
}
@ -1858,11 +2075,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
candidates = prioritizeOpenAICompactAccounts(candidates)
}
for _, acc := range candidates {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
if fresh == nil {
continue
}
@ -1910,7 +2127,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account {
if account == nil {
return nil
}
@ -1924,7 +2141,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
fresh = current
}
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact, requiredCapability) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(fresh) {
@ -1933,12 +2150,12 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
return fresh
}
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact, requiredCapability) {
return nil
}
return account
@ -1948,7 +2165,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil {
return nil
}
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) {
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact, requiredCapability) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(latest) {
@ -4861,7 +5078,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
return
}
eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" &&
if eventType != "response.completed" && eventType != "response.done" && eventType != "response.failed" &&
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
return
}

View File

@ -18,7 +18,7 @@ type stubCodexRestrictionDetector struct {
result CodexClientRestrictionDetectionResult
}
func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult {
func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account, _ []string) CodexClientRestrictionDetectionResult {
return s.result
}
@ -52,7 +52,7 @@ func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) {
c.Request.Header.Set("User-Agent", "curl/8.0")
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}}
result := got.Detect(c, account)
result := got.Detect(c, account, nil)
require.True(t, result.Enabled)
require.True(t, result.Matched)
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)

View File

@ -2242,6 +2242,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require.Equal(t, 15, usage.OutputTokens)
require.Equal(t, 4, usage.CacheReadInputTokens)
// failed 事件在部分上游路径也会携带已消耗 usage应与 WS/passthrough 保持一致
svc.parseSSEUsage(`{"type":"response.failed","response":{"usage":{"input_tokens":17,"output_tokens":19,"input_tokens_details":{"cached_tokens":6}}}}`, usage)
require.Equal(t, 17, usage.InputTokens)
require.Equal(t, 19, usage.OutputTokens)
require.Equal(t, 6, usage.CacheReadInputTokens)
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
require.Equal(t, 21, usage.InputTokens)
require.Equal(t, 8, usage.OutputTokens)

View File

@ -413,6 +413,79 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
}
func TestAccountSupportsOpenAIEndpointCapability(t *testing.T) {
t.Run("OpenAI APIKey 默认兼容 chat 和 embeddings", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
})
t.Run("OpenAI OAuth 默认仅兼容 chat", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
})
t.Run("显式列表支持同时声明 chat 和 embeddings", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions", "embeddings"},
},
}
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
})
t.Run("显式列表只声明 chat 时不支持 embeddings", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions"},
},
}
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
})
t.Run("显式 map 支持单独关闭 chat 并开启 embeddings", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"openai_capabilities": map[string]any{
"chat_completions": false,
"embeddings": true,
},
},
}
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
})
t.Run("未知能力不应默认放行", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapability("unknown")))
})
}
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
require.Equal(t,
"https://image-upstream.example/v1/images/generations",

View File

@ -48,6 +48,46 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
}
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_QuotaAutoPausedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(23)
account := Account{
ID: 77,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 2,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
"codex_5h_used_percent": 96.0,
"auto_pause_5h_threshold": 0.95,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_quota", account.ID, time.Hour))
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_quota", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "超过 5h 配额阈值的账号不应继续命中 previous_response_id 粘连")
// Auto-pause is transient, so the binding is preserved: the chain can resume on the
// same account once the quota window resets.
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_quota")
require.NoError(t, getErr)
require.Equal(t, account.ID, boundAccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
ctx := context.Background()
groupID := int64(23)
@ -268,6 +308,52 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
require.Equal(t, int64(21), selection.WaitPlan.AccountID)
}
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_CapabilityMismatchKeepsSticky(t *testing.T) {
ctx := context.Background()
groupID := int64(25)
account := Account{
ID: 31,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"openai_capabilities": []any{"chat_completions"},
},
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
store := NewOpenAIWSStateStore(cache)
cfg := newOpenAIWSV2TestConfig()
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
openaiWSStateStore: store,
}
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_capability", account.ID, time.Hour))
selection, err := svc.selectAccountByPreviousResponseIDForCapability(
ctx,
&groupID,
"resp_prev_capability",
"text-embedding-3-small",
nil,
OpenAIEndpointCapabilityEmbeddings,
false,
)
require.NoError(t, err)
require.Nil(t, selection)
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_capability")
require.NoError(t, getErr)
require.Equal(t, account.ID, boundAccountID)
}
func newOpenAIWSV2TestConfig() *config.Config {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true

View File

@ -369,7 +369,12 @@ func openAIWSEventMayContainToolCalls(eventType string) bool {
}
func openAIWSEventShouldParseUsage(eventType string) bool {
return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
switch strings.TrimSpace(eventType) {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
}
}
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
@ -2484,6 +2489,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
imageInputSize string
payloadBytes int
}
ingressSessionOriginalModel := ""
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
next, err := sjson.SetBytes(current, path, value)
@ -2547,12 +2553,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
originalModel := strings.TrimSpace(values[1].String())
modelMissing := originalModel == ""
if originalModel == "" {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"model is required in response.create payload",
nil,
)
// 入站 WS 长会话里,部分客户端只在第一轮 response.create 上声明
// model后续 turn 复用同一 session-level model。为避免因省略
// model 直接断开用户连接,这里回落到上一轮已通过校验的客户端模型,
// 并在下方写回上游 payload保证账号模型映射/fast policy/图片权限
// 仍按同一模型执行。
originalModel = ingressSessionOriginalModel
if originalModel == "" {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"model is required in response.create payload",
nil,
)
}
}
promptCacheKey := strings.TrimSpace(values[2].String())
previousResponseID := strings.TrimSpace(values[3].String())
@ -2572,7 +2587,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
if upstreamModel != originalModel {
if modelMissing || upstreamModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
if setErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
@ -2602,11 +2617,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
// single integration point for all WS ingress turns (first + follow-up
// frames flow through here).
//
// Model fallback: parseClientPayload above rejects any frame whose
// "model" field is missing (line ~2493-2500), so by the time we
// reach this point upstreamModel is always derived from a non-empty
// per-frame model. The capturedSessionModel fallback used in the
// passthrough adapter is therefore not needed in this path.
// Model fallback: first turn still requires model at the handler layer
// follow-up response.create frames may omit it and then reuse
// ingressSessionOriginalModel. We always write a concrete upstream model
// before evaluating policy, so whitelist / filter behavior remains stable.
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
if policyErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
@ -2635,6 +2649,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
normalized = policyApplied
ingressSessionOriginalModel = originalModel
return openAIWSClientPayload{
payloadRaw: normalized,
@ -3915,7 +3930,10 @@ func isOpenAIWSTokenEvent(eventType string) bool {
if strings.HasPrefix(eventType, "response.output") {
return true
}
return eventType == "response.completed" || eventType == "response.done"
// 终止事件response.completed/done/failed/...)由 isOpenAIWSTerminalEvent 单独处理。
// 不能把它们当作 token event否则当上游没有可识别的 delta 时,
// firstTokenMs 会被填到终止时刻,等于把"总耗时"误报为"首 token 延迟"。
return false
}
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
@ -3987,6 +4005,18 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
requestedModel string,
excludedIDs map[int64]struct{},
requireCompact bool,
) (*AccountSelectionResult, error) {
return s.selectAccountByPreviousResponseIDForCapability(ctx, groupID, previousResponseID, requestedModel, excludedIDs, "", requireCompact)
}
func (s *OpenAIGatewayService) selectAccountByPreviousResponseIDForCapability(
ctx context.Context,
groupID *int64,
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
requiredCapability OpenAIEndpointCapability,
requireCompact bool,
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
@ -4027,12 +4057,41 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
if !account.SupportsOpenAIEndpointCapability(requiredCapability) {
return nil, nil
}
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
// Quota auto-pause must also gate the previous_response_id sticky path; otherwise an
// account over its 5h/7d threshold keeps serving the same response chain even though
// normal scheduling skips it. Pause is transient, so fall through to normal scheduling
// without deleting the binding (the window may reset before the next turn).
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
return nil, nil
}
if s.schedulerSnapshot != nil && s.accountRepo != nil {
latest, latestErr := s.accountRepo.GetByID(ctx, account.ID)
if latestErr != nil || latest == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
if shouldClearStickySession(latest, requestedModel) || !latest.IsOpenAI() || !latest.IsSchedulable() {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
return nil, nil
}
if !latest.SupportsOpenAIEndpointCapability(requiredCapability) {
return nil, nil
}
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, latest); paused {
return nil, nil
}
if s.isOpenAIAccountRuntimeBlocked(latest) {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
account = latest
}
if requireCompact && openAICompactSupportTier(account) == 0 {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil

View File

@ -39,6 +39,24 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
require.Equal(t, 4, usage.CacheReadInputTokens)
}
func TestOpenAIWSEventShouldParseUsageTerminalEvents(t *testing.T) {
t.Parallel()
for _, eventType := range []string{
"response.completed",
"response.done",
"response.failed",
"response.incomplete",
"response.cancelled",
"response.canceled",
} {
require.True(t, openAIWSEventShouldParseUsage(eventType), eventType)
require.True(t, openAIWSEventShouldParseUsage(" "+eventType+" "), eventType)
}
require.False(t, openAIWSEventShouldParseUsage("response.output_text.delta"))
require.False(t, openAIWSEventShouldParseUsage(""))
}
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)

View File

@ -164,6 +164,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 115,
Name: "openai-ingress-omit-model",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"model_mapping": map[string]any{
"client-model": "gpt-5.1",
},
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"client-model","stream":false}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, firstEvent, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "resp_omit_model_1", gjson.GetBytes(firstEvent, "response.id").String())
writeCtx, cancelWrite = context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","stream":false,"previous_response_id":"resp_omit_model_1"}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead = context.WithTimeout(context.Background(), 3*time.Second)
_, secondEvent, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "resp_omit_model_2", gjson.GetBytes(secondEvent, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("等待 ingress websocket 结束超时")
}
require.Len(t, captureConn.writes, 2)
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[0]), "model").String())
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[1]), "model").String())
require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String())
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
gin.SetMode(gin.TestMode)
@ -441,6 +575,124 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
upstreamConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_headers","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPassthroughDialer: captureDialer,
}
account := &Account{
ID: 453,
Name: "openai-ingress-passthrough-headers",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
},
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
req.Header.Set(openAIWSTurnStateHeader, "turn-state-1")
req.Header.Set(openAIWSTurnMetadataHeader, "turn-meta-1")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "oauth-token", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"prompt_cache_key":"pcache_passthrough"}`))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "resp_passthrough_headers", gjson.GetBytes(event, "response.id").String())
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
select {
case serverErr := <-serverErrCh:
if serverErr != nil {
require.Contains(t, serverErr.Error(), "StatusNormalClosure")
}
case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时")
}
require.Equal(t, isolateOpenAISessionID(0, "pcache_passthrough"), captureDialer.lastHeaders.Get("session_id"))
require.Equal(t, "turn-state-1", captureDialer.lastHeaders.Get(openAIWSTurnStateHeader))
require.Equal(t, "turn-meta-1", captureDialer.lastHeaders.Get(openAIWSTurnMetadataHeader))
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -727,6 +727,70 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
}
func TestOpenAIGatewayService_Forward_WSv2_ResponseDoneUsageParsed(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.done","response":{"id":"resp_done_usage","model":"gpt-5.1","usage":{"input_tokens":13,"output_tokens":8,"input_tokens_details":{"cached_tokens":5},"cache_creation_input_tokens":2,"output_tokens_details":{"image_tokens":4}}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 32,
Name: "openai-ws-done",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hi"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_done_usage", result.RequestID)
require.Equal(t, 13, result.Usage.InputTokens)
require.Equal(t, 8, result.Usage.OutputTokens)
require.Equal(t, 5, result.Usage.CacheReadInputTokens)
require.Equal(t, 2, result.Usage.CacheCreationInputTokens)
require.Equal(t, 4, result.Usage.ImageOutputTokens)
}
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -0,0 +1,75 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestIsOpenAIWSTokenEvent_TerminalEventsExcluded 覆盖 isOpenAIWSTokenEvent 的回归用例。
// 重点验证终止事件response.completed / response.done不再被当作 token event
// 否则当上游没有可识别的 delta 时firstTokenMs 会被填到终止时刻,
// 等于把"总耗时"误报为"首 token 延迟"issue #2651
func TestIsOpenAIWSTokenEvent_TerminalEventsExcluded(t *testing.T) {
cases := []struct {
name string
eventType string
want bool
}{
{name: "empty", eventType: "", want: false},
{name: "whitespace_trimmed_empty", eventType: " ", want: false},
{name: "response.created", eventType: "response.created", want: false},
{name: "response.in_progress", eventType: "response.in_progress", want: false},
{name: "response.output_item.added", eventType: "response.output_item.added", want: false},
{name: "response.output_item.done", eventType: "response.output_item.done", want: false},
{name: "terminal_response.completed", eventType: "response.completed", want: false},
{name: "terminal_response.done", eventType: "response.done", want: false},
{name: "terminal_response.completed_padded", eventType: " response.completed ", want: false},
{name: "terminal_response.done_padded", eventType: " response.done ", want: false},
{name: "delta_text", eventType: "response.output_text.delta", want: true},
{name: "delta_audio_transcript", eventType: "response.audio_transcript.delta", want: true},
{name: "delta_function_call_arguments", eventType: "response.function_call_arguments.delta", want: true},
{name: "output_text_done", eventType: "response.output_text.done", want: true},
{name: "output_text_annotation_added", eventType: "response.output_text.annotation.added", want: true},
{name: "output_audio_done", eventType: "response.output_audio.done", want: true},
{name: "reasoning_summary_delta", eventType: "response.reasoning_summary_text.delta", want: true},
{name: "unrelated_event_error", eventType: "error", want: false},
{name: "unknown_event_without_match", eventType: "response.reasoning_summary_part.added", want: false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
got := isOpenAIWSTokenEvent(tc.eventType)
require.Equal(t, tc.want, got, "isOpenAIWSTokenEvent(%q)", tc.eventType)
})
}
}
// TestIsOpenAIWSTokenEvent_DisjointWithTerminal 守护「token 事件集合与终止事件集合互斥」的不变量。
// firstTokenMs 的计算依赖于 isTokenEvent && !isTerminalEvent
// 若两者再次出现交集,则 issue #2651 描述的 latency 误报会重现。
func TestIsOpenAIWSTokenEvent_DisjointWithTerminal(t *testing.T) {
terminalEvents := []string{
"response.completed",
"response.done",
"response.failed",
"response.incomplete",
"response.cancelled",
"response.canceled",
}
for _, ev := range terminalEvents {
ev := ev
t.Run(ev, func(t *testing.T) {
require.True(t, isOpenAIWSTerminalEvent(ev), "expected terminal event %q to be classified as terminal", ev)
require.False(t, isOpenAIWSTokenEvent(ev), "terminal event %q must NOT be classified as token event (issue #2651)", ev)
})
}
}

View File

@ -25,6 +25,7 @@ type Usage struct {
OutputTokens int
CacheCreationInputTokens int
CacheReadInputTokens int
ImageOutputTokens int
}
type RelayResult struct {
@ -756,8 +757,21 @@ func parseUsageAndAccumulate(
}
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
if !inputResult.Exists() {
inputResult = gjson.GetBytes(message, "response.usage.prompt_tokens")
}
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
if !outputResult.Exists() {
outputResult = gjson.GetBytes(message, "response.usage.completion_tokens")
}
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
if !cachedResult.Exists() {
cachedResult = gjson.GetBytes(message, "response.usage.prompt_tokens_details.cached_tokens")
}
imageTokens := usageResult.Get("output_tokens_details.image_tokens").Int()
if imageTokens == 0 {
imageTokens = usageResult.Get("completion_tokens_details.image_tokens").Int()
}
inputTokens, inputOK := parseUsageIntField(inputResult, true)
outputTokens, outputOK := parseUsageIntField(outputResult, true)
@ -771,14 +785,18 @@ func parseUsageAndAccumulate(
return Usage{}
}
parsedUsage := Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheReadInputTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheCreationInputTokens: int(usageResult.Get("cache_creation_input_tokens").Int()),
CacheReadInputTokens: cachedTokens,
ImageOutputTokens: int(imageTokens),
}
state.usage.InputTokens += parsedUsage.InputTokens
state.usage.OutputTokens += parsedUsage.OutputTokens
state.usage.CacheCreationInputTokens += parsedUsage.CacheCreationInputTokens
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
state.usage.ImageOutputTokens += parsedUsage.ImageOutputTokens
return parsedUsage
}
@ -840,7 +858,7 @@ func isTerminalEvent(eventType string) bool {
func shouldParseUsage(eventType string) bool {
switch eventType {
case "response.completed", "response.done", "response.failed":
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false

View File

@ -300,20 +300,41 @@ func TestParseUsageAndEnrichCoverage(t *testing.T) {
require.Equal(t, 0, state.usage.OutputTokens)
require.Equal(t, 0, state.usage.CacheReadInputTokens)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1},"cache_creation_input_tokens":4,"output_tokens_details":{"image_tokens":3}}}}`), "response.completed", nil)
require.Equal(t, 2, state.usage.InputTokens)
require.Equal(t, 1, state.usage.OutputTokens)
require.Equal(t, 1, state.usage.CacheReadInputTokens)
require.Equal(t, 4, state.usage.CacheCreationInputTokens)
require.Equal(t, 3, state.usage.ImageOutputTokens)
result := &RelayResult{}
enrichResult(result, state, 5*time.Millisecond)
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
require.Equal(t, state.usage.CacheCreationInputTokens, result.Usage.CacheCreationInputTokens)
require.Equal(t, state.usage.ImageOutputTokens, result.Usage.ImageOutputTokens)
require.Equal(t, 5*time.Millisecond, result.Duration)
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
require.Equal(t, 2, state.usage.InputTokens)
enrichResult(nil, state, 0)
}
func TestParseUsageAndAccumulateAcceptsChatUsageAliases(t *testing.T) {
t.Parallel()
state := &relayState{}
got := parseUsageAndAccumulate(
state,
[]byte(`{"type":"response.done","response":{"usage":{"prompt_tokens":12,"completion_tokens":6,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"image_tokens":2}}}}`),
"response.done",
nil,
)
require.Equal(t, 12, got.InputTokens)
require.Equal(t, 6, got.OutputTokens)
require.Equal(t, 4, got.CacheReadInputTokens)
require.Equal(t, 2, got.ImageOutputTokens)
require.Equal(t, got, state.usage)
}
func TestEmitTurnCompleteCoverage(t *testing.T) {
t.Parallel()
@ -377,6 +398,23 @@ func TestIsTokenEventCoverageBranches(t *testing.T) {
require.True(t, isTokenEvent("response.done"))
}
func TestShouldParseUsageTerminalEvents(t *testing.T) {
t.Parallel()
for _, eventType := range []string{
"response.completed",
"response.done",
"response.failed",
"response.incomplete",
"response.cancelled",
"response.canceled",
} {
require.True(t, shouldParseUsage(eventType), eventType)
}
require.False(t, shouldParseUsage("response.output_text.delta"))
require.False(t, shouldParseUsage(""))
}
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
t.Parallel()

View File

@ -312,6 +312,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// goroutine和 OnTurnComplete / final resultrunUpstreamToClient
// goroutine之间同步当前 turn 的 usage metadata。
usageMeta.initFromFirstFrame(firstClientMessage)
promptCacheKey := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "prompt_cache_key").String())
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
@ -338,7 +339,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
turnState := ""
turnMetadata := ""
if c != nil {
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
@ -519,6 +526,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
ImageOutputTokens: turn.Usage.ImageOutputTokens,
},
Model: turn.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),
@ -593,6 +601,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
ImageOutputTokens: relayResult.Usage.ImageOutputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),

View File

@ -41,6 +41,11 @@ type OpsService struct {
// cleanupReloader 由 wire 在 OpsCleanupService 构造完成后通过 SetCleanupReloader 注入。
// 解耦避免 OpsService -> OpsCleanupService 的硬依赖cleanup 也读 settings会循环
cleanupReloader CleanupReloader
// quotaAutoPauseSink 由 wire 注入(通常是 SettingService.SetOpenAIQuotaAutoPauseSettings
// UpdateOpsAdvancedSettings 写入新配置后调用,把最新的 quota auto-pause 全局默认阈值
// 立即同步到调度热路径读取的内存缓存,避免下次请求才能感知新值。
quotaAutoPauseSink func(OpsOpenAIAccountQuotaAutoPauseSettings)
}
// CleanupReloader 由 OpsCleanupService 实现。
@ -57,6 +62,16 @@ func (s *OpsService) SetCleanupReloader(r CleanupReloader) {
s.cleanupReloader = r
}
// SetOpenAIQuotaAutoPauseSettingsSink 由 wire 注入,把最新的 quota auto-pause 全局默认
// 阈值 push 到调度热路径读取的内存缓存。同 SetCleanupReloader 的解耦目的:避免 OpsService
// 持有 *SettingService 引入循环依赖。
func (s *OpsService) SetOpenAIQuotaAutoPauseSettingsSink(sink func(OpsOpenAIAccountQuotaAutoPauseSettings)) {
if s == nil {
return
}
s.quotaAutoPauseSink = sink
}
func NewOpsService(
opsRepo OpsRepository,
settingRepo SettingRepository,

View File

@ -369,6 +369,7 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{
AggregationEnabled: false,
},
OpenAIAccountQuotaAutoPause: OpsOpenAIAccountQuotaAutoPauseSettings{},
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
@ -384,6 +385,8 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
if cfg == nil {
return
}
cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold5h = clampOpsQuotaAutoPauseThreshold(cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold5h)
cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold7d = clampOpsQuotaAutoPauseThreshold(cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold7d)
cfg.DataRetention.CleanupSchedule = strings.TrimSpace(cfg.DataRetention.CleanupSchedule)
if cfg.DataRetention.CleanupSchedule == "" {
cfg.DataRetention.CleanupSchedule = opsCleanupDefaultSchedule
@ -405,6 +408,16 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
}
}
func clampOpsQuotaAutoPauseThreshold(value float64) float64 {
if value <= 0 {
return 0
}
if value > 1 {
return 1
}
return value
}
func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
if cfg == nil {
return errors.New("invalid config")
@ -477,6 +490,12 @@ func (s *OpsService) UpdateOpsAdvancedSettings(ctx context.Context, cfg *OpsAdva
if err := s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(raw)); err != nil {
return nil, err
}
// Push the new quota auto-pause settings straight into the in-memory cache that
// the OpenAI scheduling hot path reads, so the next request observes the new value
// without waiting for the background refresher's TTL.
if s.quotaAutoPauseSink != nil {
s.quotaAutoPauseSink(cfg.OpenAIAccountQuotaAutoPause)
}
// notify cleanup service to reload schedule/enabled.
if s.cleanupReloader != nil {

View File

@ -4,6 +4,9 @@ import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
@ -95,3 +98,64 @@ func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
}
}
func TestGetOpenAIQuotaAutoPauseSettings_ReadsDefaultsFromOpsAdvancedSettings(t *testing.T) {
repo := newRuntimeSettingRepoStub()
repo.values[SettingKeyOpsAdvancedSettings] = `{"openai_account_quota_auto_pause":{"default_threshold_5h":0.95,"default_threshold_7d":0.9}}`
svc := NewSettingService(repo, &config.Config{})
// Warm the in-memory cache synchronously so the assertion below is deterministic.
// GetOpenAIQuotaAutoPauseSettings is non-blocking on the hot path (returns the
// cached value, refreshes asynchronously); for tests and startup, Warm is the
// synchronous entry point that guarantees a populated cache.
settings := svc.WarmOpenAIQuotaAutoPauseSettings(context.Background())
if settings.DefaultThreshold5h != 0.95 {
t.Fatalf("DefaultThreshold5h = %v, want 0.95", settings.DefaultThreshold5h)
}
if settings.DefaultThreshold7d != 0.9 {
t.Fatalf("DefaultThreshold7d = %v, want 0.9", settings.DefaultThreshold7d)
}
// Subsequent Get must hit the warm cache and return the same value without any DB
// access — that's the hot-path invariant.
cached := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
if cached.DefaultThreshold5h != 0.95 || cached.DefaultThreshold7d != 0.9 {
t.Fatalf("cached read = %+v, want {0.95, 0.9}", cached)
}
}
// Hot-path invariant: a Get with cold cache must return immediately (zero defaults)
// rather than blocking on the DB. The async refresher will populate the cache for
// subsequent calls.
func TestGetOpenAIQuotaAutoPauseSettings_ColdCacheNonBlocking(t *testing.T) {
repo := newRuntimeSettingRepoStub()
repo.values[SettingKeyOpsAdvancedSettings] = `{"openai_account_quota_auto_pause":{"default_threshold_5h":0.7}}`
svc := NewSettingService(repo, &config.Config{})
start := time.Now()
settings := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
elapsed := time.Since(start)
if elapsed > 50*time.Millisecond {
t.Fatalf("cold-cache Get must be non-blocking, took %v", elapsed)
}
// Cold cache means we get zero defaults (the async refresh hasn't completed yet).
if settings.DefaultThreshold5h != 0 || settings.DefaultThreshold7d != 0 {
t.Fatalf("cold-cache Get = %+v, want zeroes", settings)
}
}
// Explicit cache write (e.g. from UpdateOpsAdvancedSettings) must be visible on the
// very next read without any DB roundtrip.
func TestSetOpenAIQuotaAutoPauseSettings_VisibleImmediately(t *testing.T) {
svc := NewSettingService(newRuntimeSettingRepoStub(), &config.Config{})
svc.SetOpenAIQuotaAutoPauseSettings(OpsOpenAIAccountQuotaAutoPauseSettings{
DefaultThreshold5h: 0.88,
DefaultThreshold7d: 0.77,
})
got := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
if got.DefaultThreshold5h != 0.88 || got.DefaultThreshold7d != 0.77 {
t.Fatalf("after Set, Get = %+v, want {0.88, 0.77}", got)
}
}

View File

@ -92,17 +92,23 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct {
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
DisplayAlertEvents bool `json:"display_alert_events"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
OpenAIAccountQuotaAutoPause OpsOpenAIAccountQuotaAutoPauseSettings `json:"openai_account_quota_auto_pause"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
DisplayAlertEvents bool `json:"display_alert_events"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}
type OpsOpenAIAccountQuotaAutoPauseSettings struct {
DefaultThreshold5h float64 `json:"default_threshold_5h"`
DefaultThreshold7d float64 `json:"default_threshold_7d"`
}
type OpsDataRetentionSettings struct {

View File

@ -137,10 +137,32 @@ type cachedOpenAICodexUserAgent struct {
expiresAt int64 // unix nano
}
type cachedOpenAIQuotaAutoPauseSettings struct {
settings OpsOpenAIAccountQuotaAutoPauseSettings
expiresAt int64
}
const openAICodexUserAgentCacheTTL = 60 * time.Second
const openAICodexUserAgentErrorTTL = 5 * time.Second
const openAICodexUserAgentDBTimeout = 5 * time.Second
// cachedOpenAIAllowCodexPlugin Codex 插件放行开关缓存进程内缓存60s TTL
// IsOpenAIAllowClaudeCodeCodexPluginEnabled 在每个 codex_cli_only 账号的网关请求热路径上被调用,避免每次访问 DB。
type cachedOpenAIAllowCodexPlugin struct {
value bool
expiresAt int64 // unix nano
}
const openAIAllowCodexPluginCacheTTL = 60 * time.Second
const openAIAllowCodexPluginErrorTTL = 5 * time.Second
const openAIAllowCodexPluginDBTimeout = 5 * time.Second
const openAIQuotaAutoPauseSettingsCacheTTL = 60 * time.Second
const openAIQuotaAutoPauseSettingsErrorTTL = 5 * time.Second
const openAIQuotaAutoPauseSettingsDBTimeout = 5 * time.Second
const openAIQuotaAutoPauseSettingsRefreshKey = "openai_quota_auto_pause_settings"
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
@ -152,17 +174,28 @@ type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[i
// SettingService 系统设置服务
type SettingService struct {
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
antigravityUAVersionSF singleflight.Group
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
openAICodexUASF singleflight.Group
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
antigravityUAVersionSF singleflight.Group
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
openAICodexUASF singleflight.Group
openAIAllowCodexPluginCache atomic.Value // *cachedOpenAIAllowCodexPlugin
openAIAllowCodexPluginSF singleflight.Group
// openAIQuotaAutoPauseSettingsCache holds the most recently observed quota auto-pause
// settings. GetOpenAIQuotaAutoPauseSettings reads this atomic.Value on the request hot
// path without ever blocking on the DB; when the cached entry expires, a background
// goroutine refreshes it via openAIQuotaAutoPauseSettingsSF (stale-while-revalidate).
// This per-service field also gives tests natural isolation — each SettingService
// instance owns its own cache, no shared package-level state.
openAIQuotaAutoPauseSettingsCache atomic.Value // *cachedOpenAIQuotaAutoPauseSettings
openAIQuotaAutoPauseSettingsSF singleflight.Group
}
// DefaultPlatformQuotaSetting 单 platform 三档限额nil = 沿用上层0 = 显式禁用;>0 = 上限)
@ -1015,6 +1048,54 @@ func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string {
return fallback
}
// IsOpenAIAllowClaudeCodeCodexPluginEnabled 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认关闭)。
// 仅在调用方已确认账号 codex_cli_only 开启时读取,避免对非受限账号产生无谓查询。
// 使用进程内 atomic.Value 缓存60s TTL避免在每个网关请求热路径上访问 DB。
func (s *SettingService) IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx context.Context) bool {
if cached, ok := s.openAIAllowCodexPluginCache.Load().(*cachedOpenAIAllowCodexPlugin); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value
}
}
result, _, _ := s.openAIAllowCodexPluginSF.Do("openai_allow_codex_plugin_enabled", func() (any, error) {
if cached, ok := s.openAIAllowCodexPluginCache.Load().(*cachedOpenAIAllowCodexPlugin); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAllowCodexPluginDBTimeout)
defer cancel()
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAIAllowClaudeCodeCodexPlugin)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
// 设置不存在 → 默认关闭,正常 TTL 缓存
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
value: false,
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
})
return false, nil
}
slog.Warn("failed to get openai_allow_claude_code_codex_plugin setting", "error", err)
// DB 错误 → 安全默认关闭,短 TTL 快速重试
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
value: false,
expiresAt: time.Now().Add(openAIAllowCodexPluginErrorTTL).UnixNano(),
})
return false, nil
}
enabled := value == "true"
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
value: enabled,
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
})
return enabled, nil
})
if val, ok := result.(bool); ok {
return val
}
return false
}
// SetOnUpdateCallback sets a callback function to be called when settings are updated
// This is used for cache invalidation (e.g., HTML cache in frontend server)
func (s *SettingService) SetOnUpdateCallback(callback func()) {
@ -1816,6 +1897,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
updates[SettingKeyOpenAIAllowClaudeCodeCodexPlugin] = strconv.FormatBool(settings.OpenAIAllowClaudeCodeCodexPlugin)
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
@ -1965,9 +2047,25 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
enabled: settings.OpenAIAdvancedSchedulerEnabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
// Invalidate the quota auto-pause cache and let the next read trigger a fresh load.
// We can't know from here whether ops_advanced_settings was also touched, so be
// defensive: store an expired entry — GetOpenAIQuotaAutoPauseSettings will serve
// stale and kick off an async refresh, never blocking the request that follows.
s.openAIQuotaAutoPauseSettingsSF.Forget(openAIQuotaAutoPauseSettingsRefreshKey)
if cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings); cached != nil {
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
settings: cached.settings,
expiresAt: 0,
})
}
if s.cfg != nil {
s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP)
}
s.openAIAllowCodexPluginSF.Forget("openai_allow_codex_plugin_enabled")
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
value: settings.OpenAIAllowClaudeCodeCodexPlugin,
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
})
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
@ -3233,6 +3331,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
result.OpenAIAllowClaudeCodeCodexPlugin = settings[SettingKeyOpenAIAllowClaudeCodeCodexPlugin] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
@ -4380,6 +4479,106 @@ func (s *SettingService) GetClaudeCodeVersionBounds(ctx context.Context) (min, m
return b.min, b.max
}
// GetOpenAIQuotaAutoPauseSettings returns the current global default quota auto-pause
// settings. It is invoked on the OpenAI scheduling hot path (once per request) and is
// therefore designed to never block on the DB:
//
// - Fresh cached value → returned immediately.
// - Stale or empty cache → the last known value is returned, and a background
// goroutine refreshes the cache via singleflight (stale-while-revalidate).
// - First call with no cache yet → zero defaults are returned and the same async
// refresh is kicked off; the next call gets the freshly populated value.
//
// Callers that need the freshly persisted value synchronously (tests, post-update
// confirmation, optional startup warm-up) should call WarmOpenAIQuotaAutoPauseSettings.
func (s *SettingService) GetOpenAIQuotaAutoPauseSettings(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
if s == nil {
return OpsOpenAIAccountQuotaAutoPauseSettings{}
}
cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings)
now := time.Now().UnixNano()
if cached != nil && now < cached.expiresAt {
return cached.settings
}
// Stale or unset: trigger background refresh without blocking this request.
// singleflight.DoChan dedupes concurrent refreshes; we deliberately ignore the
// returned channel — the result is observable via the atomic cache.
s.openAIQuotaAutoPauseSettingsSF.DoChan(openAIQuotaAutoPauseSettingsRefreshKey, func() (any, error) {
s.refreshOpenAIQuotaAutoPauseSettings(context.Background())
return nil, nil
})
if cached != nil {
return cached.settings // serve stale value while revalidating
}
return OpsOpenAIAccountQuotaAutoPauseSettings{}
}
// WarmOpenAIQuotaAutoPauseSettings synchronously loads the quota auto-pause settings
// into the in-memory cache. Useful for application startup (so the first request hits
// a warm cache) and for tests that need deterministic reads immediately after
// constructing the service.
func (s *SettingService) WarmOpenAIQuotaAutoPauseSettings(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
if s == nil {
return OpsOpenAIAccountQuotaAutoPauseSettings{}
}
s.refreshOpenAIQuotaAutoPauseSettings(ctx)
cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings)
if cached == nil {
return OpsOpenAIAccountQuotaAutoPauseSettings{}
}
return cached.settings
}
// refreshOpenAIQuotaAutoPauseSettings reads the latest settings from the DB and stores
// them into the in-memory cache. On error it stores the prior value (or zero defaults
// if nothing is cached yet) with the shorter error TTL so the next refresh comes
// sooner. Always uses its own timeout-bounded context to keep refresh latency
// predictable regardless of the caller.
func (s *SettingService) refreshOpenAIQuotaAutoPauseSettings(ctx context.Context) {
if s == nil || s.settingRepo == nil {
return
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIQuotaAutoPauseSettingsDBTimeout)
defer cancel()
settings := OpsOpenAIAccountQuotaAutoPauseSettings{}
ttl := openAIQuotaAutoPauseSettingsCacheTTL
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpsAdvancedSettings)
if err == nil {
cfg := defaultOpsAdvancedSettings()
if strings.TrimSpace(raw) != "" {
if jsonErr := json.Unmarshal([]byte(raw), cfg); jsonErr == nil {
normalizeOpsAdvancedSettings(cfg)
}
}
settings = cfg.OpenAIAccountQuotaAutoPause
} else if !errors.Is(err, ErrSettingNotFound) {
// Real error: keep serving prior value but refresh sooner.
if prior, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings); prior != nil {
settings = prior.settings
}
ttl = openAIQuotaAutoPauseSettingsErrorTTL
}
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
settings: settings,
expiresAt: time.Now().Add(ttl).UnixNano(),
})
}
// SetOpenAIQuotaAutoPauseSettings writes the given settings directly into the in-memory
// cache. Called from settings-write code paths so that the next read reflects the new
// value immediately, without waiting for the background refresh.
func (s *SettingService) SetOpenAIQuotaAutoPauseSettings(settings OpsOpenAIAccountQuotaAutoPauseSettings) {
if s == nil {
return
}
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
settings: settings,
expiresAt: time.Now().Add(openAIQuotaAutoPauseSettingsCacheTTL).UnixNano(),
})
}
// GetRectifierSettings 获取请求整流器配置
func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings)

View File

@ -0,0 +1,55 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type allowClaudeCodeSettingRepoStub struct{ values map[string]string }
func (s *allowClaudeCodeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unused")
}
func (s *allowClaudeCodeSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *allowClaudeCodeSettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unused")
}
func (s *allowClaudeCodeSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unused")
}
func (s *allowClaudeCodeSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unused")
}
func (s *allowClaudeCodeSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unused")
}
func (s *allowClaudeCodeSettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unused")
}
func TestSettingService_IsOpenAIAllowClaudeCodeCodexPluginEnabled(t *testing.T) {
t.Run("默认关闭(设置缺失)", func(t *testing.T) {
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{}}, &config.Config{})
require.False(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
})
t.Run("值为 true 时开启", func(t *testing.T) {
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{
SettingKeyOpenAIAllowClaudeCodeCodexPlugin: "true",
}}, &config.Config{})
require.True(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
})
t.Run("值非 true 时关闭", func(t *testing.T) {
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{
SettingKeyOpenAIAllowClaudeCodeCodexPlugin: "false",
}}, &config.Config{})
require.False(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
})
}

View File

@ -195,6 +195,7 @@ type SystemSettings struct {
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control默认 false
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent空值使用内置默认
OpenAIAllowClaudeCodeCodexPlugin bool // 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认 false
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟

View File

@ -17,6 +17,12 @@ import (
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrNoUpdateAvailable = infraerrors.Conflict("ALREADY_UP_TO_DATE", "no update available; current version is latest")
)
const (
@ -146,7 +152,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
}
if !info.HasUpdate {
return fmt.Errorf("no update available")
return ErrNoUpdateAvailable
}
// Find matching archive and checksum for current platform

View File

@ -0,0 +1,64 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type updateServiceCacheStub struct {
data string
}
func (s *updateServiceCacheStub) GetUpdateInfo(context.Context) (string, error) {
if s.data == "" {
return "", errors.New("cache miss")
}
return s.data, nil
}
func (s *updateServiceCacheStub) SetUpdateInfo(_ context.Context, data string, _ time.Duration) error {
s.data = data
return nil
}
type updateServiceGitHubClientStub struct {
release *GitHubRelease
}
func (s *updateServiceGitHubClientStub) FetchLatestRelease(context.Context, string) (*GitHubRelease, error) {
return s.release, nil
}
func (s *updateServiceGitHubClientStub) DownloadFile(context.Context, string, string, int64) error {
panic("DownloadFile should not be called when no update is available")
}
func (s *updateServiceGitHubClientStub) FetchChecksumFile(context.Context, string) ([]byte, error) {
panic("FetchChecksumFile should not be called when no update is available")
}
func TestUpdateServicePerformUpdateNoUpdateReturnsSentinel(t *testing.T) {
svc := NewUpdateService(
&updateServiceCacheStub{},
&updateServiceGitHubClientStub{
release: &GitHubRelease{
TagName: "v0.1.132",
Name: "v0.1.132",
},
},
"0.1.132",
"release",
)
err := svc.PerformUpdate(context.Background())
require.Error(t, err)
require.True(t, errors.Is(err, ErrNoUpdateAvailable))
require.ErrorIs(t, err, ErrNoUpdateAvailable)
}

View File

@ -399,6 +399,46 @@ func ProvideBackupService(
return svc
}
// ProvideOpsService constructs OpsService and wires the SettingService-backed quota
// auto-pause cache sink. Mirrors the SetCleanupReloader pattern: OpsService doesn't
// hold a *SettingService reference, but wire injects a tiny callback so writes to
// ops_advanced_settings immediately propagate into the scheduler hot-path cache.
func ProvideOpsService(
opsRepo OpsRepository,
settingRepo SettingRepository,
cfg *config.Config,
accountRepo AccountRepository,
userRepo UserRepository,
concurrencyService *ConcurrencyService,
gatewayService *GatewayService,
openAIGatewayService *OpenAIGatewayService,
geminiCompatService *GeminiMessagesCompatService,
antigravityGatewayService *AntigravityGatewayService,
systemLogSink *OpsSystemLogSink,
settingService *SettingService,
) *OpsService {
svc := NewOpsService(
opsRepo,
settingRepo,
cfg,
accountRepo,
userRepo,
concurrencyService,
gatewayService,
openAIGatewayService,
geminiCompatService,
antigravityGatewayService,
systemLogSink,
)
if settingService != nil {
svc.SetOpenAIQuotaAutoPauseSettingsSink(settingService.SetOpenAIQuotaAutoPauseSettings)
// Optional warm-up so the first scheduled request after process start observes
// a populated cache rather than zero defaults. Best-effort, sync-bounded.
settingService.WarmOpenAIQuotaAutoPauseSettings(context.Background())
}
return svc
}
// ProvideSettingService wires SettingService with group reader and proxy repo.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg)
@ -486,7 +526,7 @@ var ProviderSet = wire.NewSet(
ProvideBackupService,
NewHealthService,
ProvideOpsSystemLogSink,
NewOpsService,
ProvideOpsService,
ProvideOpsMetricsCollector,
ProvideOpsAggregationService,
ProvideOpsAlertEvaluatorService,

View File

@ -0,0 +1,16 @@
-- 为已持久化的 Antigravity model_mapping 添加 claude-opus-4-8。
--
-- 未持久化 model_mapping 的账号会直接使用 DefaultAntigravityModelMapping
-- 因此这里只需要回填已有映射对象。
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{model_mapping,claude-opus-4-8}',
'"claude-opus-4-8"'::jsonb,
true
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
AND jsonb_typeof(credentials->'model_mapping') = 'object'
AND credentials->'model_mapping'->>'claude-opus-4-8' IS NULL;

View File

@ -778,9 +778,15 @@ export interface OpsAlertRuntimeSettings {
thresholds: OpsMetricThresholds // 指标阈值配置
}
export interface OpsOpenAIAccountQuotaAutoPauseSettings {
default_threshold_5h: number // 0~10 表示不启用全局默认 5h 阈值
default_threshold_7d: number // 0~10 表示不启用全局默认 7d 阈值
}
export interface OpsAdvancedSettings {
data_retention: OpsDataRetentionSettings
aggregation: OpsAggregationSettings
openai_account_quota_auto_pause: OpsOpenAIAccountQuotaAutoPauseSettings
ignore_count_tokens_errors: boolean
ignore_context_canceled: boolean
ignore_no_available_accounts: boolean

View File

@ -561,6 +561,7 @@ export interface SystemSettings {
rewrite_message_cache_control: boolean;
antigravity_user_agent_version: string;
openai_codex_user_agent: string;
openai_allow_claude_code_codex_plugin: boolean;
web_search_emulation_enabled?: boolean;
// Payment configuration
@ -794,6 +795,7 @@ export interface UpdateSettingsRequest {
rewrite_message_cache_control?: boolean;
antigravity_user_agent_version?: string;
openai_codex_user_agent?: string;
openai_allow_claude_code_codex_plugin?: boolean;
// Payment configuration
payment_enabled?: boolean;
risk_control_enabled?: boolean;

View File

@ -222,6 +222,8 @@ const formatScopeName = (scope: string): string => {
// Claude
'claude-opus-4-6': 'COpus46',
'claude-opus-4-6-thinking': 'COpus46T',
'claude-opus-4-7': 'COpus47',
'claude-opus-4-8': 'COpus48',
'claude-sonnet-4-6': 'CSon46',
'claude-sonnet-4-5': 'CSon45',
'claude-sonnet-4-5-thinking': 'CSon45T',

View File

@ -697,6 +697,7 @@ const antigravityClaudeUsageFromAPI = computed(() =>
getAntigravityUsageFromAPI([
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
'claude-sonnet-4-6', 'claude-opus-4-6', 'claude-opus-4-6-thinking',
'claude-opus-4-7', 'claude-opus-4-8',
])
)

View File

@ -742,6 +742,50 @@
</div>
</div>
<!-- OpenAI OAuth: 额外放行 Claude Code Codex 插件 -->
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<label
id="bulk-edit-openai-codex-allow-claude-code-label"
class="input-label mb-0"
for="bulk-edit-openai-codex-allow-claude-code-enabled"
>
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}
</label>
<input
v-model="enableCodexCLIOnlyAllowClaudeCode"
id="bulk-edit-openai-codex-allow-claude-code-enabled"
type="checkbox"
aria-controls="bulk-edit-openai-codex-allow-claude-code"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
/>
</div>
<div
id="bulk-edit-openai-codex-allow-claude-code"
:class="!enableCodexCLIOnlyAllowClaudeCode && 'pointer-events-none opacity-50'"
>
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
</p>
<button
id="bulk-edit-openai-codex-allow-claude-code-toggle"
type="button"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<!-- OpenAI API Key WS mode -->
<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
@ -1219,6 +1263,7 @@ const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
const enableOpenAIAPIKeyWSMode = ref(false)
const enableCodexCLIOnly = ref(false)
const enableCodexCLIOnlyAllowClaudeCode = ref(false)
const enableOpenAICompactMode = ref(false)
const enableOpenAICompactModelMapping = ref(false)
const enableRpmLimit = ref(false)
@ -1246,6 +1291,7 @@ const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openAICompactModelMappings = ref<ModelMapping[]>([])
const rpmLimitEnabled = ref(false)
@ -1496,6 +1542,11 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
extra.codex_cli_only = codexCLIOnlyEnabled.value
}
if (enableCodexCLIOnlyAllowClaudeCode.value) {
const extra = ensureExtra()
extra.codex_cli_only_allowed_clients = codexCLIOnlyAllowClaudeCodeEnabled.value ? ['claude_code'] : []
}
if (enableOpenAICompactMode.value) {
const extra = ensureExtra()
extra.openai_compact_mode = openAICompactMode.value
@ -1602,6 +1653,7 @@ const handleSubmit = async () => {
enableOpenAIWSMode.value ||
enableOpenAIAPIKeyWSMode.value ||
enableCodexCLIOnly.value ||
enableCodexCLIOnlyAllowClaudeCode.value ||
enableOpenAICompactMode.value ||
enableOpenAICompactModelMapping.value ||
enableRpmLimit.value ||
@ -1704,6 +1756,7 @@ watch(
enableOpenAIWSMode.value = false
enableOpenAIAPIKeyWSMode.value = false
enableCodexCLIOnly.value = false
enableCodexCLIOnlyAllowClaudeCode.value = false
enableOpenAICompactMode.value = false
enableOpenAICompactModelMapping.value = false
enableRpmLimit.value = false
@ -1727,6 +1780,7 @@ watch(
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
codexCLIOnlyAllowClaudeCodeEnabled.value = false
openAICompactMode.value = 'auto'
openAICompactModelMappings.value = []
rpmLimitEnabled.value = false

View File

@ -2746,6 +2746,32 @@
/>
</button>
</div>
<div
v-if="codexCLIOnlyEnabled"
class="mt-4 flex items-center justify-between border-l-2 border-gray-200 pl-4 dark:border-dark-600"
>
<div>
<label class="input-label mb-0">{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
</p>
</div>
<button
type="button"
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<!-- OpenAI Compact 能力配置 -->
@ -2790,7 +2816,7 @@
<!-- OpenAI APIKey Responses API support mode -->
<div
v-if="form.platform === 'openai' && accountCategory === 'apikey'"
class="border-t border-gray-200 pt-4 dark:border-dark-600"
class="space-y-4 border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between gap-4">
<div>
@ -2803,10 +2829,38 @@
<Select
v-model="openAIResponsesMode"
:options="openAIResponsesModeOptions"
:disabled="!openAITextGenerationCapabilityEnabled"
data-testid="openai-responses-mode-select"
/>
</div>
</div>
<p
v-if="!openAITextGenerationCapabilityEnabled"
class="rounded-lg bg-amber-50 px-3 py-2 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300"
data-testid="openai-responses-mode-not-applicable"
>
{{ t('admin.accounts.openai.responsesModeTextDisabledHint') }}
</p>
<div>
<label class="input-label mb-2 block">{{ t('admin.accounts.openai.endpointCapabilities') }}</label>
<div class="grid grid-cols-1 gap-2 sm:grid-cols-2">
<label
v-for="option in openAIEndpointCapabilityOptions"
:key="option.value"
class="flex cursor-pointer items-center gap-2 rounded-lg border border-gray-200 px-3 py-2 text-sm dark:border-dark-600"
>
<input
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
:data-testid="`openai-endpoint-capability-${option.value}`"
:checked="openAIEndpointCapabilities.includes(option.value)"
@change="toggleOpenAIEndpointCapability(option.value, $event)"
/>
<span class="text-gray-700 dark:text-gray-200">{{ option.label }}</span>
</label>
</div>
<p class="input-hint">{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}</p>
</div>
</div>
<div>
@ -3287,7 +3341,8 @@ import type {
CreateAccountRequest,
CodexSessionImportMessage,
OpenAICompactMode,
OpenAIResponsesMode
OpenAIResponsesMode,
OpenAIEndpointCapability
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
@ -3466,9 +3521,11 @@ const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
const openAIEndpointCapabilities = ref<OpenAIEndpointCapability[]>(['chat_completions', 'embeddings'])
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
const webSearchEmulationMode = ref('default')
const webSearchGlobalEnabled = ref(false)
@ -3534,6 +3591,58 @@ const openAIResponsesModeOptions = computed(() => [
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
])
const openAITextEndpointCapabilityLabel = computed(() => {
if (openAIResponsesMode.value === 'force_responses') {
return t('admin.accounts.openai.capabilityResponses')
}
if (openAIResponsesMode.value === 'force_chat_completions') {
return t('admin.accounts.openai.capabilityChatCompletions')
}
return t('admin.accounts.openai.capabilityTextAuto')
})
const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [
{ value: 'chat_completions', label: openAITextEndpointCapabilityLabel.value },
{ value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') }
])
const openAITextGenerationCapabilityEnabled = computed(() =>
openAIEndpointCapabilities.value.includes('chat_completions')
)
const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => {
const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings']
const selected = allowed.filter((value) => values.includes(value))
return selected.length > 0 ? selected : allowed
}
const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => {
if (openAIEndpointCapabilities.value.includes(capability)) {
if (openAIEndpointCapabilities.value.length <= 1) {
const input = event?.target as HTMLInputElement | null
if (input) input.checked = true
return
}
openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter(
(value) => value !== capability
)
if (!openAITextGenerationCapabilityEnabled.value) {
openAIResponsesMode.value = 'auto'
}
return
}
openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([
...openAIEndpointCapabilities.value,
capability
])
}
const applyOpenAIEndpointCapabilities = (credentials: Record<string, unknown>) => {
const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value)
if (capabilities.length === 2) {
delete credentials.openai_capabilities
return
}
credentials.openai_capabilities = capabilities
}
function buildAntigravityExtra(): Record<string, unknown> | undefined {
const extra: Record<string, unknown> = {}
@ -3847,9 +3956,11 @@ watch(
}
if (newPlatform !== 'openai') {
openaiPassthroughEnabled.value = false
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
codexCLIOnlyAllowClaudeCodeEnabled.value = false
}
if (newPlatform !== 'anthropic') {
anthropicPassthroughEnabled.value = false
@ -3870,6 +3981,7 @@ watch(
([category, platform]) => {
if (platform === 'openai' && category !== 'oauth-based') {
codexCLIOnlyEnabled.value = false
codexCLIOnlyAllowClaudeCodeEnabled.value = false
}
if (platform !== 'anthropic' || category !== 'apikey') {
anthropicPassthroughEnabled.value = false
@ -4268,9 +4380,11 @@ const resetForm = () => {
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
openAIResponsesMode.value = 'auto'
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
codexCLIOnlyAllowClaudeCodeEnabled.value = false
anthropicPassthroughEnabled.value = false
webSearchEmulationMode.value = 'default'
// Reset quota control state
@ -4353,13 +4467,26 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
} else {
delete extra.codex_cli_only
}
if (
accountCategory.value === 'oauth-based' &&
codexCLIOnlyEnabled.value &&
codexCLIOnlyAllowClaudeCodeEnabled.value
) {
extra.codex_cli_only_allowed_clients = ['claude_code']
} else {
delete extra.codex_cli_only_allowed_clients
}
if (openAICompactMode.value !== 'auto') {
extra.openai_compact_mode = openAICompactMode.value
} else {
delete extra.openai_compact_mode
}
if (accountCategory.value === 'apikey' && openAIResponsesMode.value !== 'auto') {
if (
accountCategory.value === 'apikey' &&
openAITextGenerationCapabilityEnabled.value &&
openAIResponsesMode.value !== 'auto'
) {
extra.openai_responses_mode = openAIResponsesMode.value
} else {
delete extra.openai_responses_mode
@ -4689,6 +4816,7 @@ const handleSubmit = async () => {
}
}
if (form.platform === 'openai') {
applyOpenAIEndpointCapabilities(credentials)
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping
@ -4811,6 +4939,9 @@ const createAccountAndFinish = async (
}
}
if (platform === 'openai') {
if (type === 'apikey') {
applyOpenAIEndpointCapabilities(credentials)
}
const compactModelMapping = buildOpenAICompactModelMapping()
if (compactModelMapping) {
credentials.compact_model_mapping = compactModelMapping

View File

@ -1439,7 +1439,7 @@
<!-- OpenAI APIKey Responses API support mode -->
<div
v-if="account?.platform === 'openai' && account?.type === 'apikey'"
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-3"
class="space-y-4 border-t border-gray-200 pt-4 dark:border-dark-600"
>
<div class="flex items-center justify-between gap-4">
<div>
@ -1452,13 +1452,44 @@
<Select
v-model="openAIResponsesMode"
:options="openAIResponsesModeOptions"
:disabled="!openAITextGenerationCapabilityEnabled"
data-testid="openai-responses-mode-select"
/>
</div>
</div>
<div class="rounded-lg bg-gray-50 px-3 py-2 text-xs text-gray-600 dark:bg-dark-700 dark:text-gray-300">
<div
v-if="openAITextGenerationCapabilityEnabled"
class="rounded-lg bg-gray-50 px-3 py-2 text-xs text-gray-600 dark:bg-dark-700 dark:text-gray-300"
>
<span class="font-medium">{{ t(openAIResponsesStatusKey) }}</span>
</div>
<div
v-else
class="rounded-lg bg-amber-50 px-3 py-2 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300"
data-testid="openai-responses-mode-not-applicable"
>
{{ t('admin.accounts.openai.responsesModeTextDisabledHint') }}
</div>
<div>
<label class="input-label mb-2 block">{{ t('admin.accounts.openai.endpointCapabilities') }}</label>
<div class="grid grid-cols-1 gap-2 sm:grid-cols-2">
<label
v-for="option in openAIEndpointCapabilityOptions"
:key="option.value"
class="flex cursor-pointer items-center gap-2 rounded-lg border border-gray-200 px-3 py-2 text-sm dark:border-dark-600"
>
<input
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
:data-testid="`openai-endpoint-capability-${option.value}`"
:checked="openAIEndpointCapabilities.includes(option.value)"
@change="toggleOpenAIEndpointCapability(option.value, $event)"
/>
<span class="text-gray-700 dark:text-gray-200">{{ option.label }}</span>
</label>
</div>
<p class="input-hint">{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}</p>
</div>
</div>
<!-- Anthropic API Key 自动透传开关 -->
@ -1642,6 +1673,32 @@
/>
</button>
</div>
<div
v-if="codexCLIOnlyEnabled"
class="mt-4 flex items-center justify-between border-l-2 border-gray-200 pl-4 dark:border-dark-600"
>
<div>
<label class="input-label mb-0">{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
</p>
</div>
<button
type="button"
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<div
@ -1730,6 +1787,84 @@
</div>
</div>
<div
v-if="account?.platform === 'openai'"
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
>
<div class="space-y-2">
<div class="flex items-center justify-between">
<label class="input-label mb-0">{{ t('admin.accounts.autoPause5hDisabled') }}</label>
<button
type="button"
@click="autoPause5hDisabled = !autoPause5hDisabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
autoPause5hDisabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
data-testid="auto-pause-5h-disabled"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
autoPause5hDisabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<p class="input-hint">{{ t('admin.accounts.autoPauseDisabledHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.autoPause5hThreshold') }}</label>
<input
v-model.number="autoPause5hThreshold"
type="number"
min="0"
max="100"
step="0.1"
class="input"
:disabled="autoPause5hDisabled"
data-testid="auto-pause-5h-threshold"
/>
<p class="input-hint">{{ t('admin.accounts.autoPauseThresholdHint') }}</p>
</div>
<div class="space-y-2">
<div class="flex items-center justify-between">
<label class="input-label mb-0">{{ t('admin.accounts.autoPause7dDisabled') }}</label>
<button
type="button"
@click="autoPause7dDisabled = !autoPause7dDisabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
autoPause7dDisabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
data-testid="auto-pause-7d-disabled"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
autoPause7dDisabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<p class="input-hint">{{ t('admin.accounts.autoPauseDisabledHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.autoPause7dThreshold') }}</label>
<input
v-model.number="autoPause7dThreshold"
type="number"
min="0"
max="100"
step="0.1"
class="input"
:disabled="autoPause7dDisabled"
data-testid="auto-pause-7d-threshold"
/>
<p class="input-hint">{{ t('admin.accounts.autoPauseThresholdHint') }}</p>
</div>
</div>
<!-- 配额控制 (Anthropic OAuth/SetupToken: 亲和 + 窗口费用 + 会话 + RPM ) -->
<div
v-if="account?.platform === 'anthropic' && (account?.type === 'oauth' || account?.type === 'setup-token')"
@ -2245,7 +2380,15 @@ import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode, OpenAIResponsesMode } from '@/types'
import type {
Account,
Proxy,
AdminGroup,
CheckMixedChannelResponse,
OpenAICompactMode,
OpenAIResponsesMode,
OpenAIEndpointCapability
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@ -2382,6 +2525,10 @@ const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(false)
const autoPause5hThreshold = ref<number | null>(null)
const autoPause7dThreshold = ref<number | null>(null)
const autoPause5hDisabled = ref(false)
const autoPause7dDisabled = ref(false)
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
@ -2433,9 +2580,11 @@ const customBaseUrl = ref('')
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref<OpenAICompactMode>('auto')
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
const openAIEndpointCapabilities = ref<OpenAIEndpointCapability[]>(['chat_completions', 'embeddings'])
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
type CodexImageGenerationBridgeMode = 'inherit' | 'enabled' | 'disabled'
const codexImageGenerationBridgeMode = ref<CodexImageGenerationBridgeMode>('inherit')
const anthropicPassthroughEnabled = ref(false)
@ -2539,6 +2688,85 @@ const openAIResponsesModeOptions = computed(() => [
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
])
const openAITextEndpointCapabilityLabel = computed(() => {
if (openAIResponsesMode.value === 'force_responses') {
return t('admin.accounts.openai.capabilityResponses')
}
if (openAIResponsesMode.value === 'force_chat_completions') {
return t('admin.accounts.openai.capabilityChatCompletions')
}
const extra = props.account?.extra as Record<string, unknown> | undefined
if (extra?.openai_responses_supported === true) {
return t('admin.accounts.openai.capabilityResponsesAuto')
}
if (extra?.openai_responses_supported === false) {
return t('admin.accounts.openai.capabilityChatCompletionsAuto')
}
return t('admin.accounts.openai.capabilityTextAuto')
})
const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [
{ value: 'chat_completions', label: openAITextEndpointCapabilityLabel.value },
{ value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') }
])
const openAITextGenerationCapabilityEnabled = computed(() =>
openAIEndpointCapabilities.value.includes('chat_completions')
)
const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => {
const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings']
const selected = allowed.filter((value) => values.includes(value))
return selected.length > 0 ? selected : allowed
}
const readOpenAIEndpointCapabilities = (credentials?: Record<string, unknown>): OpenAIEndpointCapability[] => {
const raw = credentials?.openai_capabilities
if (Array.isArray(raw)) {
return normalizeOpenAIEndpointCapabilities(
raw.filter((value): value is OpenAIEndpointCapability =>
value === 'chat_completions' || value === 'embeddings'
)
)
}
if (raw !== null && typeof raw === 'object') {
const capabilityMap = raw as Record<string, unknown>
return normalizeOpenAIEndpointCapabilities(
openAIEndpointCapabilityOptions.value
.map((option) => option.value)
.filter((value) => capabilityMap[value] === true)
)
}
return ['chat_completions', 'embeddings']
}
const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => {
if (openAIEndpointCapabilities.value.includes(capability)) {
if (openAIEndpointCapabilities.value.length <= 1) {
const input = event?.target as HTMLInputElement | null
if (input) input.checked = true
return
}
openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter(
(value) => value !== capability
)
if (!openAITextGenerationCapabilityEnabled.value) {
openAIResponsesMode.value = 'auto'
}
return
}
openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([
...openAIEndpointCapabilities.value,
capability
])
}
const applyOpenAIEndpointCapabilities = (credentials: Record<string, unknown>) => {
const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value)
if (capabilities.length === 2) {
delete credentials.openai_capabilities
return
}
credentials.openai_capabilities = capabilities
}
const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => {
if (mode === 'force_responses' || mode === 'force_chat_completions') {
return mode
@ -2716,18 +2944,24 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load mixed scheduling setting (only for antigravity accounts)
mixedScheduling.value = false
allowOverages.value = false
const extra = newAccount.extra as Record<string, unknown> | undefined
mixedScheduling.value = extra?.mixed_scheduling === true
allowOverages.value = extra?.allow_overages === true
const extra = newAccount.extra as Record<string, unknown> | undefined
mixedScheduling.value = extra?.mixed_scheduling === true
allowOverages.value = extra?.allow_overages === true
autoPause5hThreshold.value = typeof extra?.auto_pause_5h_threshold === 'number' ? extra.auto_pause_5h_threshold * 100 : null
autoPause7dThreshold.value = typeof extra?.auto_pause_7d_threshold === 'number' ? extra.auto_pause_7d_threshold * 100 : null
autoPause5hDisabled.value = extra?.auto_pause_5h_disabled === true
autoPause7dDisabled.value = extra?.auto_pause_7d_disabled === true
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
openAIResponsesMode.value = 'auto'
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
openAICompactModelMappings.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
codexCLIOnlyAllowClaudeCodeEnabled.value = false
codexImageGenerationBridgeMode.value = 'inherit'
anthropicPassthroughEnabled.value = false
webSearchEmulationMode.value = 'default'
@ -2736,6 +2970,12 @@ const syncFormFromAccount = (newAccount: Account | null) => {
openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
if (newAccount.type === 'apikey') {
openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode)
openAIEndpointCapabilities.value = readOpenAIEndpointCapabilities(
newAccount.credentials as Record<string, unknown> | undefined
)
if (!openAITextGenerationCapabilityEnabled.value) {
openAIResponsesMode.value = 'auto'
}
}
const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean'
? extra.codex_image_generation_bridge
@ -2759,6 +2999,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
})
if (newAccount.type === 'oauth') {
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
codexCLIOnlyAllowClaudeCodeEnabled.value =
Array.isArray(extra?.codex_cli_only_allowed_clients) &&
(extra.codex_cli_only_allowed_clients as unknown[]).includes('claude_code')
}
const credentials = newAccount.credentials as Record<string, unknown> | undefined
const compactMappings = credentials?.compact_model_mapping as Record<string, string> | undefined
@ -3476,6 +3719,7 @@ const handleSubmit = async () => {
newCredentials.model_mapping = currentCredentials.model_mapping
}
if (props.account.platform === 'openai') {
applyOpenAIEndpointCapabilities(newCredentials)
const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
if (compactModelMapping) {
newCredentials.compact_model_mapping = compactModelMapping
@ -3829,9 +4073,9 @@ const handleSubmit = async () => {
}
// For OpenAI OAuth/API Key accounts, handle passthrough mode in extra
if (props.account.platform === 'openai' && (props.account.type === 'oauth' || props.account.type === 'apikey')) {
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
const newExtra: Record<string, unknown> = { ...currentExtra }
if (props.account.platform === 'openai' && (props.account.type === 'oauth' || props.account.type === 'apikey')) {
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
const newExtra: Record<string, unknown> = { ...currentExtra }
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
if (props.account.type === 'oauth') {
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
@ -3853,15 +4097,35 @@ const handleSubmit = async () => {
} else {
newExtra.openai_compact_mode = openAICompactMode.value
}
if (props.account.type === 'apikey') {
if (openAIResponsesMode.value === 'auto') {
if (props.account.type === 'apikey') {
if (!openAITextGenerationCapabilityEnabled.value || openAIResponsesMode.value === 'auto') {
delete newExtra.openai_responses_mode
} else {
newExtra.openai_responses_mode = openAIResponsesMode.value
}
}
}
if (autoPause5hThreshold.value != null && autoPause5hThreshold.value > 0) {
newExtra.auto_pause_5h_threshold = autoPause5hThreshold.value / 100
} else {
delete newExtra.auto_pause_5h_threshold
}
if (autoPause7dThreshold.value != null && autoPause7dThreshold.value > 0) {
newExtra.auto_pause_7d_threshold = autoPause7dThreshold.value / 100
} else {
delete newExtra.auto_pause_7d_threshold
}
if (autoPause5hDisabled.value) {
newExtra.auto_pause_5h_disabled = true
} else {
delete newExtra.auto_pause_5h_disabled
}
if (autoPause7dDisabled.value) {
newExtra.auto_pause_7d_disabled = true
} else {
delete newExtra.auto_pause_7d_disabled
}
delete newExtra.codex_image_generation_bridge_enabled
delete newExtra.codex_image_generation_bridge_enabled
if (codexImageGenerationBridgeMode.value === 'inherit') {
delete newExtra.codex_image_generation_bridge
} else {
@ -3877,6 +4141,12 @@ const handleSubmit = async () => {
} else {
delete newExtra.codex_cli_only
}
// codex_cli_only Claude Code
if (codexCLIOnlyEnabled.value && codexCLIOnlyAllowClaudeCodeEnabled.value) {
newExtra.codex_cli_only_allowed_clients = ['claude_code']
} else {
delete newExtra.codex_cli_only_allowed_clients
}
}
updatePayload.extra = newExtra

View File

@ -197,6 +197,25 @@ describe('BulkEditAccountModal', () => {
})
})
it('OpenAI OAuth 批量编辑应提交 codex_cli_only_allowed_clients 字段', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
selectedTypes: ['oauth']
})
await wrapper.get('#bulk-edit-openai-codex-allow-claude-code-enabled').setValue(true)
await wrapper.get('#bulk-edit-openai-codex-allow-claude-code-toggle').trigger('click')
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
await flushPromises()
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
extra: {
codex_cli_only_allowed_clients: ['claude_code']
}
})
})
it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],

View File

@ -310,6 +310,137 @@ describe('EditAccountModal', () => {
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
})
it('submits OpenAI APIKey endpoint capabilities from credentials', async () => {
const account = buildAccount()
account.credentials.openai_capabilities = ['chat_completions']
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
expect(wrapper.findAll('input[type="checkbox"]').some((input) => (input.element as HTMLInputElement).checked)).toBe(true)
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
'chat_completions'
])
})
it('submits OpenAI quota auto-pause thresholds in extra', async () => {
const account = buildAccount()
account.extra = {
auto_pause_5h_threshold: 0.9,
auto_pause_7d_threshold: 0.8
}
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
await wrapper.get('[data-testid="auto-pause-5h-threshold"]').setValue('95')
await wrapper.get('[data-testid="auto-pause-7d-threshold"]').setValue('96')
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_5h_threshold).toBe(0.95)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_7d_threshold).toBe(0.96)
})
it('submits OpenAI quota auto-pause disable flag in extra', async () => {
// Toggling the per-account disable flag must persist as auto_pause_5h_disabled
// so an admin can exempt one account from auto-pause even when a global default
// threshold is configured (otherwise leaving the threshold blank would silently
// fall back to the global default).
const account = buildAccount()
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
await wrapper.get('[data-testid="auto-pause-5h-disabled"]').trigger('click')
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_5h_disabled).toBe(true)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_7d_disabled).toBeUndefined()
})
it('keeps at least one OpenAI APIKey endpoint capability selected', async () => {
const account = buildAccount()
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
const chatCheckbox = wrapper.get<HTMLInputElement>(
'[data-testid="openai-endpoint-capability-chat_completions"]'
)
const embeddingsCheckbox = wrapper.get<HTMLInputElement>(
'[data-testid="openai-endpoint-capability-embeddings"]'
)
expect(chatCheckbox.element.checked).toBe(true)
expect(embeddingsCheckbox.element.checked).toBe(true)
await embeddingsCheckbox.setValue(false)
expect(chatCheckbox.element.checked).toBe(true)
expect(embeddingsCheckbox.element.checked).toBe(false)
await chatCheckbox.setValue(false)
expect(chatCheckbox.element.checked).toBe(true)
expect(embeddingsCheckbox.element.checked).toBe(false)
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
'chat_completions'
])
})
it('disables text generation protocol when only embeddings requests are accepted', async () => {
const account = buildAccount()
account.credentials.openai_capabilities = ['embeddings']
account.extra = {
openai_responses_mode: 'force_responses',
openai_responses_supported: true
}
updateAccountMock.mockReset()
checkMixedChannelRiskMock.mockReset()
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
updateAccountMock.mockResolvedValue(account)
const wrapper = mountModal(account)
const responsesModeSelect = wrapper.get<HTMLSelectElement>(
'[data-testid="openai-responses-mode-select"]'
)
expect(responsesModeSelect.element.disabled).toBe(true)
expect(wrapper.find('[data-testid="openai-responses-mode-not-applicable"]').exists()).toBe(true)
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
expect(updateAccountMock).toHaveBeenCalledTimes(1)
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
'embeddings'
])
expect(updateAccountMock.mock.calls[0]?.[1]?.extra).not.toHaveProperty('openai_responses_mode')
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
})
it('submits account-level Codex image generation bridge override', async () => {
const account = buildAccount()
account.extra = {

View File

@ -35,6 +35,11 @@ describe('useModelWhitelist', () => {
expect(models).toContain('gemini-3-pro-image')
})
it('Claude 模型列表包含 Opus 4.8', () => {
expect(getModelsByPlatform('claude')).toContain('claude-opus-4-8')
expect(getModelsByPlatform('antigravity')).toContain('claude-opus-4-8')
})
it('gemini 模型列表包含原生生图模型', () => {
const models = getModelsByPlatform('gemini')

View File

@ -29,6 +29,7 @@ export const claudeModels = [
'claude-opus-4-5-20251101',
'claude-opus-4-6',
'claude-opus-4-7',
'claude-opus-4-8',
'claude-sonnet-4-6'
]
@ -53,6 +54,7 @@ const antigravityModels = [
'claude-opus-4-6',
'claude-opus-4-6-thinking',
'claude-opus-4-7',
'claude-opus-4-8',
'claude-opus-4-5-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-5',
@ -263,6 +265,7 @@ const anthropicPresetMappings = [
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'claude-opus-4-8', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
{ label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4-5-20251001', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
{ label: 'Opus->Sonnet', from: 'claude-opus-4-6', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
@ -322,7 +325,8 @@ const antigravityPresetMappings = [
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'claude-opus-4-8', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]
// Windsurf 预设映射
@ -332,6 +336,7 @@ const windsurfPresetMappings: { label: string; from: string; to: string; color:
const bedrockPresetMappings = [
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'us.anthropic.claude-opus-4-7-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'us.anthropic.claude-opus-4-8-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },

View File

@ -3361,10 +3361,21 @@ export default {
'Automatic passthrough is currently enabled: it only affects HTTP passthrough and does not disable WS mode.',
responsesMode: 'Responses API support',
responsesModeDesc:
'Only applies to OpenAI API Key accounts. Auto follows probe results; force modes override probing.',
'Only applies to the OpenAI API Key text forwarding path. Auto follows probe results; force modes override probing.',
responsesModeAuto: 'Auto',
responsesModeForceResponses: 'Force Responses',
responsesModeForceChatCompletions: 'Force Chat Completions',
responsesModeTextDisabledHint:
'Not applicable when the Responses / Chat Completions endpoint is not enabled.',
endpointCapabilities: 'Endpoint capabilities',
endpointCapabilitiesDesc:
'Used by account routing. The text endpoint follows the Responses API support setting above and is shown as Responses, Chat Completions, or auto mode; Embeddings independently controls /v1/embeddings.',
capabilityResponses: 'Responses',
capabilityTextAuto: 'Responses / Chat Completions (Auto)',
capabilityResponsesAuto: 'Responses (auto probe)',
capabilityChatCompletions: 'Chat Completions',
capabilityChatCompletionsAuto: 'Chat Completions (auto probe)',
capabilityEmbeddings: 'Embeddings',
responsesStatusAutoSupported: 'Auto probe: Responses',
responsesStatusAutoUnsupported: 'Auto probe: Chat Completions',
responsesStatusAutoUnknown: 'Auto probe: unknown',
@ -3373,6 +3384,9 @@ export default {
codexCLIOnly: 'Codex official clients only',
codexCLIOnlyDesc:
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
codexCLIOnlyAllowClaudeCode: "Also allow Claude Code's Codex plugin",
codexCLIOnlyAllowClaudeCodeDesc:
'Only takes effect when the switch above is on. Additionally allows requests from the Claude Code Codex plugin (exact match on originator=Claude Code) without weakening blocking of other non-official clients.',
codexImageGenerationBridge: 'Codex image-generation bridge',
codexImageGenerationBridgeDesc:
'Account policy takes precedence over channel and global settings. Only controls whether Codex requests through the /responses text endpoint receive the image_generation tool; standalone image-generation endpoints are unaffected.',
@ -3473,6 +3487,12 @@ export default {
'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
autoPauseOnExpired: 'Auto Pause On Expired',
autoPauseOnExpiredDesc: 'When enabled, the account will auto pause scheduling after it expires',
autoPause5hThreshold: '5h Usage Threshold (%)',
autoPause7dThreshold: '7d Usage Threshold (%)',
autoPauseThresholdHint: 'Leave empty or set 0 to use the global default threshold (configured in Ops settings); set a value to override the global default. Reaching the threshold only skips the account during scheduling and does not modify schedulable.',
autoPause5hDisabled: 'Disable 5h auto-pause',
autoPause7dDisabled: 'Disable 7d auto-pause',
autoPauseDisabledHint: 'When enabled, this account is never auto-paused (even if a global default threshold is configured).',
// Quota control (Anthropic OAuth/SetupToken only)
quotaControl: {
title: 'Quota Control',
@ -5203,6 +5223,11 @@ export default {
aggregation: 'Pre-aggregation Tasks',
enableAggregation: 'Enable Pre-aggregation',
aggregationHint: 'Pre-aggregation improves query performance for long time windows',
openaiQuotaAutoPause: 'OpenAI Account Quota Auto-pause',
openaiQuotaAutoPauseHint: 'When an OpenAI account reaches its 5h / 7d usage threshold, the scheduler skips it automatically and resumes once the window rolls over. Per-account thresholds take precedence over this global default.',
openaiQuotaAutoPauseDefault5h: 'Default 5h usage threshold (%)',
openaiQuotaAutoPauseDefault7d: 'Default 7d usage threshold (%)',
openaiQuotaAutoPauseThresholdHint: 'Value 0-100; leave blank or 0 to disable the global default threshold.',
errorFiltering: 'Error Filtering',
ignoreCountTokensErrors: 'Ignore count_tokens errors',
ignoreCountTokensErrorsHint: 'When enabled, errors from count_tokens requests will not be written to the error log.',
@ -5233,7 +5258,8 @@ export default {
slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100',
ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0',
requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100',
upstreamErrorRateMaxRange: 'Upstream error rate maximum must be between 0 and 100'
upstreamErrorRateMaxRange: 'Upstream error rate maximum must be between 0 and 100',
openaiQuotaAutoPauseRange: 'OpenAI quota auto-pause threshold must be between 0 and 100'
}
},
concurrency: {
@ -5627,6 +5653,9 @@ export default {
openaiCodexUserAgent: 'OpenAI Codex UA',
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
openaiAllowClaudeCodeCodexPlugin: "Allow using the Codex plugin in Claude Code",
openaiAllowClaudeCodeCodexPluginDesc:
"Global switch; only affects OpenAI OAuth accounts that have 'Codex official clients only' enabled. When on, all such accounts additionally allow requests from the Claude Code Codex plugin (exact match on originator=Claude Code) without per-account config; upstream requests remain pass-through.",
},
webSearchEmulation: {
title: 'Web Search Emulation',

View File

@ -3507,10 +3507,20 @@ export default {
responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
responsesMode: 'Responses API 支持',
responsesModeDesc:
'仅对 OpenAI API Key 生效。自动跟随探测结果,强制模式会覆盖自动探测。',
'仅对 OpenAI API Key 的文本转发链路生效。自动跟随探测结果,强制模式会覆盖自动探测。',
responsesModeAuto: '自动',
responsesModeForceResponses: '强制 Responses',
responsesModeForceChatCompletions: '强制 Chat Completions',
responsesModeTextDisabledHint: '未启用 Responses / Chat Completions 端点时,此设置不适用。',
endpointCapabilities: '端点能力',
endpointCapabilitiesDesc:
'用于调度筛选。文本端点会跟随上方 Responses API 支持显示为 Responses、Chat Completions 或自动模式Embeddings 独立控制 /v1/embeddings。',
capabilityResponses: 'Responses',
capabilityTextAuto: 'Responses / Chat Completions自动',
capabilityResponsesAuto: 'Responses自动探测',
capabilityChatCompletions: 'Chat Completions',
capabilityChatCompletionsAuto: 'Chat Completions自动探测',
capabilityEmbeddings: 'Embeddings',
responsesStatusAutoSupported: '自动探测Responses',
responsesStatusAutoUnsupported: '自动探测Chat Completions',
responsesStatusAutoUnknown: '自动探测:未探测',
@ -3518,6 +3528,8 @@ export default {
responsesStatusForcedChatCompletions: '已强制 Chat Completions',
codexCLIOnly: '仅允许 Codex 官方客户端',
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
codexCLIOnlyAllowClaudeCode: '额外放行 Claude Code 的 Codex 插件',
codexCLIOnlyAllowClaudeCodeDesc: '仅在上方开关开启时生效。额外放行通过 Claude Code 的 Codex 插件发起的请求(精确匹配 originator=Claude Code不影响对其他非官方客户端的拦截。',
codexImageGenerationBridge: 'Codex 图片生成桥接',
codexImageGenerationBridgeDesc:
'账号级策略优先于渠道和全局配置。仅控制 Codex 走 /responses 文本端点时是否注入 image_generation 工具;不影响独立图片生成接口。',
@ -3613,6 +3625,12 @@ export default {
interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token',
autoPauseOnExpired: '过期自动暂停调度',
autoPauseOnExpiredDesc: '启用后,账号过期将自动暂停调度',
autoPause5hThreshold: '5h 用量阈值(%)',
autoPause7dThreshold: '7d 用量阈值(%)',
autoPauseThresholdHint: '留空或填 0 表示使用全局默认阈值(在运维设置中配置);填具体值则覆盖全局默认。达到阈值后仅在调度时跳过账号,不修改 schedulable。',
autoPause5hDisabled: '禁用 5h 自动暂停',
autoPause7dDisabled: '禁用 7d 自动暂停',
autoPauseDisabledHint: '开启后该账号永不进入自动暂停(即使全局默认阈值已配置)。',
// Quota control (Anthropic OAuth/SetupToken only)
quotaControl: {
title: '配额控制',
@ -5364,6 +5382,11 @@ export default {
aggregation: '预聚合任务',
enableAggregation: '启用预聚合任务',
aggregationHint: '预聚合可提升长时间窗口查询性能',
openaiQuotaAutoPause: 'OpenAI 账号配额自动暂停',
openaiQuotaAutoPauseHint: '当 OpenAI 账号 5h / 7d 用量达到阈值时,调度会自动跳过该账号;窗口滚动后自动恢复。账号级阈值优先于此全局默认值。',
openaiQuotaAutoPauseDefault5h: '默认 5h 用量阈值 (%)',
openaiQuotaAutoPauseDefault7d: '默认 7d 用量阈值 (%)',
openaiQuotaAutoPauseThresholdHint: '取值 0-100留空或 0 表示不启用全局默认阈值。',
errorFiltering: '错误过滤',
ignoreCountTokensErrors: '忽略 count_tokens 错误',
ignoreCountTokensErrorsHint: '启用后count_tokens 请求的错误将不会写入错误日志。',
@ -5395,7 +5418,8 @@ export default {
slaMinPercentRange: 'SLA最低百分比必须在0-100之间',
ttftP99MaxRange: 'TTFT P99最大值必须大于等于0',
requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间',
upstreamErrorRateMaxRange: '上游错误率最大值必须在0-100之间'
upstreamErrorRateMaxRange: '上游错误率最大值必须在0-100之间',
openaiQuotaAutoPauseRange: 'OpenAI 配额自动暂停阈值必须在 0-100 之间'
}
},
concurrency: {
@ -5783,6 +5807,9 @@ export default {
openaiCodexUserAgent: 'OpenAI Codex UA',
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
openaiAllowClaudeCodeCodexPlugin: '允许在 Claude Code 中使用 Codex 插件',
openaiAllowClaudeCodeCodexPluginDesc:
'全局开关,仅对已开启「仅允许 Codex 官方客户端」的 OpenAI OAuth 账号生效。开启后,所有此类账号都额外放行通过 Claude Code 的 Codex 插件发起的请求(精确匹配 originator=Claude Code无需逐账号配置上游请求仍保持透传。',
},
webSearchEmulation: {
title: 'Web Search 模拟',

View File

@ -1185,6 +1185,7 @@ export interface CodexUsageSnapshot {
export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off'
export type OpenAIResponsesMode = 'auto' | 'force_responses' | 'force_chat_completions'
export type OpenAIEndpointCapability = 'chat_completions' | 'embeddings'
export interface OpenAICompactState {
openai_compact_mode?: OpenAICompactMode

View File

@ -3948,6 +3948,19 @@
}}
</p>
</div>
<!-- 是否允许在 Claude Code 中使用 Codex 插件全局开关 -->
<div class="flex items-center justify-between">
<div class="pr-4">
<label class="block text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t("admin.settings.gatewayForwarding.openaiAllowClaudeCodeCodexPlugin") }}
</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t("admin.settings.gatewayForwarding.openaiAllowClaudeCodeCodexPluginDesc") }}
</p>
</div>
<Toggle v-model="form.openai_allow_claude_code_codex_plugin" />
</div>
</div>
</div>
<!-- Web Search Emulation -->
@ -7163,6 +7176,7 @@ const form = reactive<SettingsForm>({
rewrite_message_cache_control: false,
antigravity_user_agent_version: "",
openai_codex_user_agent: "",
openai_allow_claude_code_codex_plugin: false,
//
balance_low_notify_enabled: false,
balance_low_notify_threshold: 0,
@ -8269,6 +8283,7 @@ async function saveSettings() {
form.antigravity_user_agent_version?.trim() || "",
openai_codex_user_agent:
form.openai_codex_user_agent?.trim() || "",
openai_allow_claude_code_codex_plugin: form.openai_allow_claude_code_codex_plugin,
// Payment configuration
payment_enabled: form.payment_enabled,
risk_control_enabled: form.risk_control_enabled,

View File

@ -50,6 +50,10 @@ async function loadAllSettings() {
runtimeSettings.value = runtime
emailConfig.value = email
advancedSettings.value = advanced
// payload
if (advancedSettings.value && !advancedSettings.value.openai_account_quota_auto_pause) {
advancedSettings.value.openai_account_quota_auto_pause = { default_threshold_5h: 0, default_threshold_7d: 0 }
}
// 使
if (thresholds && Object.keys(thresholds).length > 0) {
metricThresholds.value = {
@ -119,6 +123,28 @@ function removeRecipient(target: 'alert' | 'report', email: string) {
if (idx >= 0) list.splice(idx, 1)
}
// OpenAI 0~1 UI (0~100)
const quotaAutoPause5hPercent = computed<number | null>({
get() {
const v = advancedSettings.value?.openai_account_quota_auto_pause?.default_threshold_5h
return v && v > 0 ? Math.round(v * 1000) / 10 : null
},
set(val) {
if (!advancedSettings.value?.openai_account_quota_auto_pause) return
advancedSettings.value.openai_account_quota_auto_pause.default_threshold_5h = val != null && val > 0 ? val / 100 : 0
}
})
const quotaAutoPause7dPercent = computed<number | null>({
get() {
const v = advancedSettings.value?.openai_account_quota_auto_pause?.default_threshold_7d
return v && v > 0 ? Math.round(v * 1000) / 10 : null
},
set(val) {
if (!advancedSettings.value?.openai_account_quota_auto_pause) return
advancedSettings.value.openai_account_quota_auto_pause.default_threshold_7d = val != null && val > 0 ? val / 100 : 0
}
})
//
const validation = computed(() => {
const errors: string[] = []
@ -145,6 +171,11 @@ const validation = computed(() => {
if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) {
errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
}
const { default_threshold_5h, default_threshold_7d } = advancedSettings.value.openai_account_quota_auto_pause
if (default_threshold_5h < 0 || default_threshold_5h > 1 || default_threshold_7d < 0 || default_threshold_7d > 1) {
errors.push(t('admin.ops.settings.validation.openaiQuotaAutoPauseRange'))
}
}
//
@ -473,6 +504,40 @@ async function saveAllSettings() {
</div>
</div>
<!-- OpenAI 账号配额自动暂停全局默认阈值 -->
<div class="space-y-3">
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.openaiQuotaAutoPause') }}</h5>
<p class="text-xs text-gray-500">{{ t('admin.ops.settings.openaiQuotaAutoPauseHint') }}</p>
<div class="grid grid-cols-1 gap-4 md:grid-cols-2">
<div>
<label class="input-label">{{ t('admin.ops.settings.openaiQuotaAutoPauseDefault5h') }}</label>
<input
v-model.number="quotaAutoPause5hPercent"
type="number"
min="0"
max="100"
step="0.1"
class="input"
data-testid="ops-quota-auto-pause-5h"
/>
</div>
<div>
<label class="input-label">{{ t('admin.ops.settings.openaiQuotaAutoPauseDefault7d') }}</label>
<input
v-model.number="quotaAutoPause7dPercent"
type="number"
min="0"
max="100"
step="0.1"
class="input"
data-testid="ops-quota-auto-pause-7d"
/>
</div>
</div>
<p class="text-xs text-gray-500">{{ t('admin.ops.settings.openaiQuotaAutoPauseThresholdHint') }}</p>
</div>
<!-- Error Filtering -->
<div class="space-y-3">
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.errorFiltering') }}</h5>