chore: merge upstream Wei-Shaw/sub2api v0.1.133
This commit is contained in:
commit
a420179abb
@ -1 +1 @@
|
|||||||
0.1.132
|
0.1.133
|
||||||
|
|||||||
@ -199,7 +199,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.ProvideOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink, settingService)
|
||||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@ -22,18 +22,16 @@ const (
|
|||||||
PlatformOpenAI = "openai"
|
PlatformOpenAI = "openai"
|
||||||
PlatformGemini = "gemini"
|
PlatformGemini = "gemini"
|
||||||
PlatformAntigravity = "antigravity"
|
PlatformAntigravity = "antigravity"
|
||||||
PlatformWindsurf = "windsurf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
const (
|
const (
|
||||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||||
AccountTypeWindsurfSession = "windsurf-session" // Windsurf Session 类型账号(邮箱密码登录获取的 session token + api_key)
|
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
|
||||||
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@ -74,7 +72,8 @@ const (
|
|||||||
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
|
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
|
||||||
var DefaultAntigravityModelMapping = map[string]string{
|
var DefaultAntigravityModelMapping = map[string]string{
|
||||||
// Claude 白名单
|
// Claude 白名单
|
||||||
"claude-opus-4-7": "claude-opus-4-6", // 官方模型
|
"claude-opus-4-8": "claude-opus-4-8", // 官方模型
|
||||||
|
"claude-opus-4-7": "claude-opus-4-7", // 官方模型
|
||||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
||||||
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
||||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||||
@ -124,6 +123,7 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||||
var DefaultBedrockModelMapping = map[string]string{
|
var DefaultBedrockModelMapping = map[string]string{
|
||||||
// Claude Opus
|
// Claude Opus
|
||||||
|
"claude-opus-4-8": "us.anthropic.claude-opus-4-8-v1",
|
||||||
"claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1",
|
"claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1",
|
||||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
|||||||
@ -24,3 +24,27 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAntigravityModelMapping_ContainsOpus48(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, ok := DefaultAntigravityModelMapping["claude-opus-4-8"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected mapping for claude-opus-4-8 to exist")
|
||||||
|
}
|
||||||
|
if got != "claude-opus-4-8" {
|
||||||
|
t.Fatalf("unexpected claude-opus-4-8 mapping: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultBedrockModelMapping_ContainsOpus48(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, ok := DefaultBedrockModelMapping["claude-opus-4-8"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected Bedrock mapping for claude-opus-4-8 to exist")
|
||||||
|
}
|
||||||
|
if got != "us.anthropic.claude-opus-4-8-v1" {
|
||||||
|
t.Fatalf("unexpected Bedrock claude-opus-4-8 mapping: got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -256,6 +256,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
RewriteMessageCacheControl: settings.RewriteMessageCacheControl,
|
||||||
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
AntigravityUserAgentVersion: settings.AntigravityUserAgentVersion,
|
||||||
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
|
OpenAICodexUserAgent: settings.OpenAICodexUserAgent,
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin: settings.OpenAIAllowClaudeCodeCodexPlugin,
|
||||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||||
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
|
||||||
@ -584,6 +585,7 @@ type UpdateSettingsRequest struct {
|
|||||||
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
RewriteMessageCacheControl *bool `json:"rewrite_message_cache_control"`
|
||||||
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
AntigravityUserAgentVersion *string `json:"antigravity_user_agent_version"`
|
||||||
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
|
OpenAICodexUserAgent *string `json:"openai_codex_user_agent"`
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin *bool `json:"openai_allow_claude_code_codex_plugin"`
|
||||||
|
|
||||||
// Payment visible method routing
|
// Payment visible method routing
|
||||||
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
|
||||||
@ -1655,6 +1657,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return previousSettings.OpenAICodexUserAgent
|
return previousSettings.OpenAICodexUserAgent
|
||||||
}(),
|
}(),
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin: func() bool {
|
||||||
|
if req.OpenAIAllowClaudeCodeCodexPlugin != nil {
|
||||||
|
return *req.OpenAIAllowClaudeCodeCodexPlugin
|
||||||
|
}
|
||||||
|
return previousSettings.OpenAIAllowClaudeCodeCodexPlugin
|
||||||
|
}(),
|
||||||
PaymentVisibleMethodAlipaySource: func() string {
|
PaymentVisibleMethodAlipaySource: func() string {
|
||||||
if req.PaymentVisibleMethodAlipaySource != nil {
|
if req.PaymentVisibleMethodAlipaySource != nil {
|
||||||
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
|
||||||
@ -2031,6 +2039,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
RewriteMessageCacheControl: updatedSettings.RewriteMessageCacheControl,
|
||||||
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
AntigravityUserAgentVersion: updatedSettings.AntigravityUserAgentVersion,
|
||||||
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
|
OpenAICodexUserAgent: updatedSettings.OpenAICodexUserAgent,
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin: updatedSettings.OpenAIAllowClaudeCodeCodexPlugin,
|
||||||
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
|
||||||
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
|
||||||
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
|
||||||
@ -2500,6 +2509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
|
if before.OpenAICodexUserAgent != after.OpenAICodexUserAgent {
|
||||||
changed = append(changed, "openai_codex_user_agent")
|
changed = append(changed, "openai_codex_user_agent")
|
||||||
}
|
}
|
||||||
|
if before.OpenAIAllowClaudeCodeCodexPlugin != after.OpenAIAllowClaudeCodeCodexPlugin {
|
||||||
|
changed = append(changed, "openai_allow_claude_code_codex_plugin")
|
||||||
|
}
|
||||||
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
|
||||||
changed = append(changed, "payment_visible_method_alipay_source")
|
changed = append(changed, "payment_visible_method_alipay_source")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -17,12 +18,18 @@ import (
|
|||||||
|
|
||||||
// SystemHandler handles system-related operations
|
// SystemHandler handles system-related operations
|
||||||
type SystemHandler struct {
|
type SystemHandler struct {
|
||||||
updateSvc *service.UpdateService
|
updateSvc systemUpdateService
|
||||||
lockSvc *service.SystemOperationLockService
|
lockSvc *service.SystemOperationLockService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type systemUpdateService interface {
|
||||||
|
CheckUpdate(ctx context.Context, force bool) (*service.UpdateInfo, error)
|
||||||
|
PerformUpdate(ctx context.Context) error
|
||||||
|
Rollback() error
|
||||||
|
}
|
||||||
|
|
||||||
// NewSystemHandler creates a new SystemHandler
|
// NewSystemHandler creates a new SystemHandler
|
||||||
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
|
func NewSystemHandler(updateSvc systemUpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
|
||||||
return &SystemHandler{
|
return &SystemHandler{
|
||||||
updateSvc: updateSvc,
|
updateSvc: updateSvc,
|
||||||
lockSvc: lockSvc,
|
lockSvc: lockSvc,
|
||||||
@ -67,6 +74,21 @@ func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
|
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
|
||||||
|
if errors.Is(err, service.ErrNoUpdateAvailable) {
|
||||||
|
info, checkErr := h.updateSvc.CheckUpdate(ctx, false)
|
||||||
|
if checkErr != nil {
|
||||||
|
releaseReason = "SYSTEM_UPDATE_FAILED"
|
||||||
|
return nil, checkErr
|
||||||
|
}
|
||||||
|
succeeded = true
|
||||||
|
return gin.H{
|
||||||
|
"message": "Already up to date",
|
||||||
|
"already_up_to_date": true,
|
||||||
|
"current_version": info.CurrentVersion,
|
||||||
|
"latest_version": info.LatestVersion,
|
||||||
|
"operation_id": lock.OperationID(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
releaseReason = "SYSTEM_UPDATE_FAILED"
|
releaseReason = "SYSTEM_UPDATE_FAILED"
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
144
backend/internal/handler/admin/system_handler_test.go
Normal file
144
backend/internal/handler/admin/system_handler_test.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type systemHandlerUpdateServiceStub struct {
|
||||||
|
performErr error
|
||||||
|
updateInfo *service.UpdateInfo
|
||||||
|
checkErr error
|
||||||
|
checkForces []bool
|
||||||
|
performCall int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemHandlerUpdateServiceStub) CheckUpdate(_ context.Context, force bool) (*service.UpdateInfo, error) {
|
||||||
|
s.checkForces = append(s.checkForces, force)
|
||||||
|
return s.updateInfo, s.checkErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemHandlerUpdateServiceStub) PerformUpdate(context.Context) error {
|
||||||
|
s.performCall++
|
||||||
|
return s.performErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *systemHandlerUpdateServiceStub) Rollback() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type systemUpdateResponseEnvelope struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
AlreadyUpToDate bool `json:"already_up_to_date"`
|
||||||
|
CurrentVersion string `json:"current_version"`
|
||||||
|
LatestVersion string `json:"latest_version"`
|
||||||
|
OperationID string `json:"operation_id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type systemUpdateErrorEnvelope struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSystemHandlerTestRouter(t *testing.T, updateSvc *systemHandlerUpdateServiceStub, repo *memoryIdempotencyRepoStub) *gin.Engine {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
service.SetDefaultIdempotencyCoordinator(nil)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
service.SetDefaultIdempotencyCoordinator(nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
lockSvc := service.NewSystemOperationLockService(repo, service.IdempotencyConfig{
|
||||||
|
ProcessingTimeout: time.Second,
|
||||||
|
SystemOperationTTL: time.Minute,
|
||||||
|
})
|
||||||
|
handler := NewSystemHandler(updateSvc, lockSvc)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.POST("/api/v1/admin/system/update", handler.PerformUpdate)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireSystemLockStatus(t *testing.T, repo *memoryIdempotencyRepoStub, wantStatus string) {
|
||||||
|
t.Helper()
|
||||||
|
repo.mu.Lock()
|
||||||
|
defer repo.mu.Unlock()
|
||||||
|
|
||||||
|
for _, record := range repo.data {
|
||||||
|
if record.Status == wantStatus {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("system lock status %q not found in records: %#v", wantStatus, repo.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemHandlerPerformUpdateAlreadyUpToDateReturnsOK(t *testing.T) {
|
||||||
|
updateSvc := &systemHandlerUpdateServiceStub{
|
||||||
|
performErr: service.ErrNoUpdateAvailable,
|
||||||
|
updateInfo: &service.UpdateInfo{
|
||||||
|
CurrentVersion: "0.1.132",
|
||||||
|
LatestVersion: "0.1.132",
|
||||||
|
HasUpdate: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := newMemoryIdempotencyRepoStub()
|
||||||
|
router := newSystemHandlerTestRouter(t, updateSvc, repo)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/system/update", nil)
|
||||||
|
req.Header.Set("Idempotency-Key", "already-up-to-date")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, 1, updateSvc.performCall)
|
||||||
|
require.Equal(t, []bool{false}, updateSvc.checkForces)
|
||||||
|
requireSystemLockStatus(t, repo, service.IdempotencyStatusSucceeded)
|
||||||
|
|
||||||
|
var body systemUpdateResponseEnvelope
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body))
|
||||||
|
require.Equal(t, 0, body.Code)
|
||||||
|
require.Equal(t, "success", body.Message)
|
||||||
|
require.Equal(t, "Already up to date", body.Data.Message)
|
||||||
|
require.True(t, body.Data.AlreadyUpToDate)
|
||||||
|
require.Equal(t, "0.1.132", body.Data.CurrentVersion)
|
||||||
|
require.Equal(t, "0.1.132", body.Data.LatestVersion)
|
||||||
|
require.NotEmpty(t, body.Data.OperationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSystemHandlerPerformUpdateFailureStillReturnsInternalError(t *testing.T) {
|
||||||
|
updateSvc := &systemHandlerUpdateServiceStub{
|
||||||
|
performErr: errors.New("download failed"),
|
||||||
|
}
|
||||||
|
repo := newMemoryIdempotencyRepoStub()
|
||||||
|
router := newSystemHandlerTestRouter(t, updateSvc, repo)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/system/update", nil)
|
||||||
|
req.Header.Set("Idempotency-Key", "real-failure")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
require.Equal(t, 1, updateSvc.performCall)
|
||||||
|
require.Empty(t, updateSvc.checkForces)
|
||||||
|
requireSystemLockStatus(t, repo, service.IdempotencyStatusFailedRetryable)
|
||||||
|
|
||||||
|
var body systemUpdateErrorEnvelope
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body))
|
||||||
|
require.Equal(t, http.StatusInternalServerError, body.Code)
|
||||||
|
require.Equal(t, "internal error", body.Message)
|
||||||
|
}
|
||||||
27
backend/internal/handler/concurrency_error_response.go
Normal file
27
backend/internal/handler/concurrency_error_response.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const statusClientClosedRequest = 499
|
||||||
|
|
||||||
|
func concurrencyErrorResponse(err error, slotType string) (int, string, string) {
|
||||||
|
var concurrencyErr *ConcurrencyError
|
||||||
|
if errors.As(err, &concurrencyErr) {
|
||||||
|
if concurrencyErr.SlotType != "" {
|
||||||
|
slotType = concurrencyErr.SlotType
|
||||||
|
}
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error",
|
||||||
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return statusClientClosedRequest, "api_error", "context canceled"
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable, please retry later"
|
||||||
|
}
|
||||||
63
backend/internal/handler/concurrency_error_response_test.go
Normal file
63
backend/internal/handler/concurrency_error_response_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConcurrencyErrorResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
slotType string
|
||||||
|
wantStatus int
|
||||||
|
wantType string
|
||||||
|
wantMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "true concurrency timeout remains rate limit",
|
||||||
|
err: &ConcurrencyError{SlotType: "account", IsTimeout: true},
|
||||||
|
slotType: "user",
|
||||||
|
wantStatus: http.StatusTooManyRequests,
|
||||||
|
wantType: "rate_limit_error",
|
||||||
|
wantMessage: "Concurrency limit exceeded for account, please retry later",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client cancellation is not classified as concurrency limit",
|
||||||
|
err: context.Canceled,
|
||||||
|
slotType: "user",
|
||||||
|
wantStatus: statusClientClosedRequest,
|
||||||
|
wantType: "api_error",
|
||||||
|
wantMessage: "context canceled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deadline exceeded is service unavailable",
|
||||||
|
err: context.DeadlineExceeded,
|
||||||
|
slotType: "user",
|
||||||
|
wantStatus: http.StatusServiceUnavailable,
|
||||||
|
wantType: "api_error",
|
||||||
|
wantMessage: "Service temporarily unavailable, please retry later",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "redis acquire error is service unavailable",
|
||||||
|
err: errors.New("redis unavailable"),
|
||||||
|
slotType: "user",
|
||||||
|
wantStatus: http.StatusServiceUnavailable,
|
||||||
|
wantType: "api_error",
|
||||||
|
wantMessage: "Service temporarily unavailable, please retry later",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
status, errType, message := concurrencyErrorResponse(tt.err, tt.slotType)
|
||||||
|
require.Equal(t, tt.wantStatus, status)
|
||||||
|
require.Equal(t, tt.wantType, errType)
|
||||||
|
require.Equal(t, tt.wantMessage, message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -185,6 +185,7 @@ type SystemSettings struct {
|
|||||||
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
RewriteMessageCacheControl bool `json:"rewrite_message_cache_control"`
|
||||||
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
AntigravityUserAgentVersion string `json:"antigravity_user_agent_version"`
|
||||||
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
|
OpenAICodexUserAgent string `json:"openai_codex_user_agent"`
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin bool `json:"openai_allow_claude_code_codex_plugin"`
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||||
|
|||||||
@ -535,7 +535,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
ParsedRequest: parsedReq,
|
||||||
@ -965,7 +965,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
|
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
ParsedRequest: parsedReq,
|
||||||
@ -1531,10 +1531,10 @@ func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, su
|
|||||||
return min
|
return min
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
// handleConcurrencyError handles concurrency-related acquire errors.
|
||||||
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
status, errType, message := concurrencyErrorResponse(err, slotType)
|
||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
h.handleStreamingAwareError(c, status, errType, message, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
@ -2138,10 +2138,11 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
func (h *GatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||||
if task == nil {
|
if task == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task = wrapUsageRecordTaskContext(parent, task)
|
||||||
if h.usageRecordWorkerPool != nil {
|
if h.usageRecordWorkerPool != nil {
|
||||||
h.usageRecordWorkerPool.Submit(task)
|
h.usageRecordWorkerPool.Submit(task)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -292,7 +292,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
QuotaPlatform: quotaPlatform,
|
QuotaPlatform: quotaPlatform,
|
||||||
|
|||||||
@ -267,7 +267,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
QuotaPlatform: quotaPlatform,
|
QuotaPlatform: quotaPlatform,
|
||||||
|
|||||||
@ -336,6 +336,9 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
if parentErr := c.Request.Context().Err(); parentErr != nil {
|
||||||
|
return nil, parentErr
|
||||||
|
}
|
||||||
return nil, &ConcurrencyError{
|
return nil, &ConcurrencyError{
|
||||||
SlotType: slotType,
|
SlotType: slotType,
|
||||||
IsTimeout: true,
|
IsTimeout: true,
|
||||||
|
|||||||
@ -280,6 +280,25 @@ func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWaitForSlotWithPingTimeout_ParentContextCanceled(t *testing.T) {
|
||||||
|
cache := &helperConcurrencyCacheStub{
|
||||||
|
accountSeq: []bool{false},
|
||||||
|
}
|
||||||
|
concurrency := service.NewConcurrencyService(cache)
|
||||||
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
reqCtx, cancel := context.WithCancel(c.Request.Context())
|
||||||
|
c.Request = c.Request.WithContext(reqCtx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
|
||||||
|
require.Nil(t, release)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
var cErr *ConcurrencyError
|
||||||
|
require.False(t, errors.As(err, &cErr))
|
||||||
|
}
|
||||||
|
|
||||||
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||||
errCache := &helperConcurrencyCacheStubWithError{
|
errCache := &helperConcurrencyCacheStubWithError{
|
||||||
err: errors.New("redis unavailable"),
|
err: errors.New("redis unavailable"),
|
||||||
|
|||||||
@ -528,7 +528,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
inboundEndpoint := GetInboundEndpoint(c)
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
QuotaPlatform: quotaPlatform,
|
QuotaPlatform: quotaPlatform,
|
||||||
|
|||||||
@ -127,7 +127,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
"",
|
"",
|
||||||
@ -135,6 +135,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
reqModel,
|
reqModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportAny,
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
service.OpenAIEndpointCapabilityChatCompletions,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -273,7 +274,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
inboundEndpoint := GetInboundEndpoint(c)
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
|
||||||
|
|
||||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
|
|||||||
@ -107,7 +107,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
|||||||
routingStart := time.Now()
|
routingStart := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, _, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, _, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
"",
|
"",
|
||||||
@ -115,6 +115,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
|||||||
reqModel,
|
reqModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportHTTPSSE,
|
service.OpenAIUpstreamTransportHTTPSSE,
|
||||||
|
service.OpenAIEndpointCapabilityEmbeddings,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -140,13 +141,6 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
if account.Type != service.AccountTypeAPIKey {
|
|
||||||
if selection.ReleaseFunc != nil {
|
|
||||||
selection.ReleaseFunc()
|
|
||||||
}
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
|
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
|
||||||
@ -220,7 +214,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
|||||||
inboundEndpoint := GetInboundEndpoint(c)
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
@ -46,6 +47,31 @@ func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedM
|
|||||||
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
|
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func usageRecordContext(parent context.Context, base context.Context) context.Context {
|
||||||
|
if base == nil {
|
||||||
|
base = context.Background()
|
||||||
|
}
|
||||||
|
if parent == nil {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
if clientRequestID, _ := parent.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||||
|
base = context.WithValue(base, ctxkey.ClientRequestID, strings.TrimSpace(clientRequestID))
|
||||||
|
}
|
||||||
|
if requestID, _ := parent.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||||
|
base = context.WithValue(base, ctxkey.RequestID, strings.TrimSpace(requestID))
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapUsageRecordTaskContext(parent context.Context, task service.UsageRecordTask) service.UsageRecordTask {
|
||||||
|
if task == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return func(ctx context.Context) {
|
||||||
|
task(usageRecordContext(parent, ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
func NewOpenAIGatewayHandler(
|
func NewOpenAIGatewayHandler(
|
||||||
gatewayService *service.OpenAIGatewayService,
|
gatewayService *service.OpenAIGatewayService,
|
||||||
@ -266,7 +292,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
// Select account supporting the requested model
|
// Select account supporting the requested model
|
||||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
previousResponseID,
|
previousResponseID,
|
||||||
@ -274,6 +300,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
reqModel,
|
reqModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportAny,
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
service.OpenAIEndpointCapabilityChatCompletions,
|
||||||
requireCompact,
|
requireCompact,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -437,7 +464,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
@ -675,7 +702,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
currentRoutingModel = effectiveMappedModel
|
currentRoutingModel = effectiveMappedModel
|
||||||
}
|
}
|
||||||
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
|
||||||
c.Request.Context(),
|
c.Request.Context(),
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
"", // no previous_response_id
|
"", // no previous_response_id
|
||||||
@ -683,6 +710,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
currentRoutingModel,
|
currentRoutingModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportAny,
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
service.OpenAIEndpointCapabilityChatCompletions,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -821,7 +849,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
inboundEndpoint := GetInboundEndpoint(c)
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
|
|
||||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
@ -1273,7 +1301,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability(
|
||||||
ctx,
|
ctx,
|
||||||
apiKey.GroupID,
|
apiKey.GroupID,
|
||||||
previousResponseID,
|
previousResponseID,
|
||||||
@ -1281,6 +1309,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
reqModel,
|
reqModel,
|
||||||
failedAccountIDs,
|
failedAccountIDs,
|
||||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||||
|
service.OpenAIEndpointCapabilityChatCompletions,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1424,7 +1453,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
inboundEndpoint := GetInboundEndpoint(c)
|
inboundEndpoint := GetInboundEndpoint(c)
|
||||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
@ -1609,10 +1638,11 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
func (h *OpenAIGatewayHandler) submitUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||||
if task == nil {
|
if task == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task = wrapUsageRecordTaskContext(parent, task)
|
||||||
if h.usageRecordWorkerPool != nil {
|
if h.usageRecordWorkerPool != nil {
|
||||||
h.usageRecordWorkerPool.Submit(task)
|
h.usageRecordWorkerPool.Submit(task)
|
||||||
return
|
return
|
||||||
@ -1631,18 +1661,19 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
|||||||
task(ctx)
|
task(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(parent context.Context, result *service.OpenAIForwardResult, task service.UsageRecordTask) {
|
||||||
if result != nil && result.ImageCount > 0 {
|
if result != nil && result.ImageCount > 0 {
|
||||||
h.submitMandatoryUsageRecordTask(task)
|
h.submitMandatoryUsageRecordTask(parent, task)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.submitUsageRecordTask(task)
|
h.submitUsageRecordTask(parent, task)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
|
func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(parent context.Context, task service.UsageRecordTask) {
|
||||||
if task == nil {
|
if task == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
task = wrapUsageRecordTaskContext(parent, task)
|
||||||
if h.usageRecordWorkerPool != nil {
|
if h.usageRecordWorkerPool != nil {
|
||||||
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped {
|
||||||
return
|
return
|
||||||
@ -1685,10 +1716,10 @@ func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, stream
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
// handleConcurrencyError handles concurrency-related acquire errors.
|
||||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
status, errType, message := concurrencyErrorResponse(err, slotType)
|
||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
h.handleStreamingAwareError(c, status, errType, message, streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||||
|
|||||||
@ -867,8 +867,11 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
|
|||||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||||
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||||
}
|
}
|
||||||
logs := repo.logSnapshot()
|
var logs []service.ContentModerationLog
|
||||||
require.Len(t, logs, 1)
|
require.Eventually(t, func() bool {
|
||||||
|
logs = repo.logSnapshot()
|
||||||
|
return len(logs) == 1
|
||||||
|
}, time.Second, 10*time.Millisecond)
|
||||||
require.True(t, logs[0].Flagged)
|
require.True(t, logs[0].Flagged)
|
||||||
require.Equal(t, service.ContentModerationActionBlock, logs[0].Action)
|
require.Equal(t, service.ContentModerationActionBlock, logs[0].Action)
|
||||||
require.Equal(t, "bad prompt", logs[0].InputExcerpt)
|
require.Equal(t, "bad prompt", logs[0].InputExcerpt)
|
||||||
|
|||||||
@ -0,0 +1,41 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
|
||||||
|
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-request-123")
|
||||||
|
parent = context.WithValue(parent, ctxkey.RequestID, "request-456")
|
||||||
|
|
||||||
|
var gotClientRequestID string
|
||||||
|
var gotRequestID string
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.submitUsageRecordTask(parent, func(ctx context.Context) {
|
||||||
|
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
|
||||||
|
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, "client-request-123", gotClientRequestID)
|
||||||
|
require.Equal(t, "request-456", gotRequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISubmitUsageRecordTaskCopiesRequestContext(t *testing.T) {
|
||||||
|
parent := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-request-123")
|
||||||
|
parent = context.WithValue(parent, ctxkey.RequestID, "openai-request-456")
|
||||||
|
|
||||||
|
var gotClientRequestID string
|
||||||
|
var gotRequestID string
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.submitUsageRecordTask(parent, func(ctx context.Context) {
|
||||||
|
gotClientRequestID, _ = ctx.Value(ctxkey.ClientRequestID).(string)
|
||||||
|
gotRequestID, _ = ctx.Value(ctxkey.RequestID).(string)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, "openai-client-request-123", gotClientRequestID)
|
||||||
|
require.Equal(t, "openai-request-456", gotRequestID)
|
||||||
|
}
|
||||||
@ -311,7 +311,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
|||||||
if result != nil {
|
if result != nil {
|
||||||
upstreamModel = result.UpstreamModel
|
upstreamModel = result.UpstreamModel
|
||||||
}
|
}
|
||||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
h.submitMandatoryUsageRecordTask(c.Request.Context(), func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
|
|||||||
@ -29,7 +29,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
|||||||
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
|
|||||||
h := &GatewayHandler{}
|
h := &GatewayHandler{}
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
if _, ok := ctx.Deadline(); !ok {
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
t.Fatal("expected deadline in fallback context")
|
t.Fatal("expected deadline in fallback context")
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
|
|||||||
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||||
h := &GatewayHandler{}
|
h := &GatewayHandler{}
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
h.submitUsageRecordTask(nil)
|
h.submitUsageRecordTask(context.Background(), nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,12 +66,12 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *t
|
|||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
panic("usage task panic")
|
panic("usage task panic")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
called.Store(true)
|
called.Store(true)
|
||||||
})
|
})
|
||||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
@ -82,7 +82,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
|||||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
|
|||||||
h := &OpenAIGatewayHandler{}
|
h := &OpenAIGatewayHandler{}
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
if _, ok := ctx.Deadline(); !ok {
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
t.Fatal("expected deadline in fallback context")
|
t.Fatal("expected deadline in fallback context")
|
||||||
}
|
}
|
||||||
@ -110,7 +110,7 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *te
|
|||||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||||
h := &OpenAIGatewayHandler{}
|
h := &OpenAIGatewayHandler{}
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
h.submitUsageRecordTask(nil)
|
h.submitUsageRecordTask(context.Background(), nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,12 +119,12 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
|
|||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
panic("usage task panic")
|
panic("usage task panic")
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
called.Store(true)
|
called.Store(true)
|
||||||
})
|
})
|
||||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
@ -152,7 +152,7 @@ func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallb
|
|||||||
pool.Submit(func(ctx context.Context) {})
|
pool.Submit(func(ctx context.Context) {})
|
||||||
|
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
|
h.submitMandatoryUsageRecordTask(context.Background(), func(ctx context.Context) {
|
||||||
called.Store(true)
|
called.Store(true)
|
||||||
})
|
})
|
||||||
close(release)
|
close(release)
|
||||||
@ -182,7 +182,7 @@ func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandator
|
|||||||
pool.Submit(func(ctx context.Context) {})
|
pool.Submit(func(ctx context.Context) {})
|
||||||
|
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
h.submitOpenAIUsageRecordTask(context.Background(), &service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
|
||||||
called.Store(true)
|
called.Store(true)
|
||||||
})
|
})
|
||||||
close(release)
|
close(release)
|
||||||
|
|||||||
@ -155,6 +155,7 @@ var claudeModels = []modelDef{
|
|||||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||||
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||||
{ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"},
|
{ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"},
|
||||||
|
{ID: "claude-opus-4-8", DisplayName: "Claude Opus 4.8", CreatedAt: "2026-05-29T00:00:00Z"},
|
||||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
requiredIDs := []string{
|
requiredIDs := []string{
|
||||||
|
"claude-opus-4-8",
|
||||||
"claude-opus-4-6-thinking",
|
"claude-opus-4-6-thinking",
|
||||||
"gemini-2.5-flash-image",
|
"gemini-2.5-flash-image",
|
||||||
"gemini-2.5-flash-image-preview",
|
"gemini-2.5-flash-image-preview",
|
||||||
|
|||||||
@ -204,6 +204,8 @@ type modelInfo struct {
|
|||||||
// 只有在此映射表中的模型才会注入身份提示词
|
// 只有在此映射表中的模型才会注入身份提示词
|
||||||
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
|
// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。
|
||||||
var modelInfoMap = map[string]modelInfo{
|
var modelInfoMap = map[string]modelInfo{
|
||||||
|
"claude-opus-4-8": {DisplayName: "Claude Opus 4.8", CanonicalID: "claude-opus-4-8"},
|
||||||
|
"claude-opus-4-7": {DisplayName: "Claude Opus 4.7", CanonicalID: "claude-opus-4-7"},
|
||||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||||
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
|
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
|
||||||
@ -587,7 +589,8 @@ func maxOutputTokensLimit(model string) int {
|
|||||||
func isAntigravityOpusHighTierModel(model string) bool {
|
func isAntigravityOpusHighTierModel(model string) bool {
|
||||||
lower := strings.ToLower(model)
|
lower := strings.ToLower(model)
|
||||||
return strings.HasPrefix(lower, "claude-opus-4-6") ||
|
return strings.HasPrefix(lower, "claude-opus-4-6") ||
|
||||||
strings.HasPrefix(lower, "claude-opus-4-7")
|
strings.HasPrefix(lower, "claude-opus-4-7") ||
|
||||||
|
strings.HasPrefix(lower, "claude-opus-4-8")
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||||
|
|||||||
@ -1597,3 +1597,139 @@ func TestAnthropicToResponses_TemperatureStrippedForAllGpt5Variants(t *testing.T
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// AnthropicToResponsesResponse: Anthropic input_tokens excludes cached tokens
|
||||||
|
// while OpenAI Responses input_tokens is the total including cached tokens.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAnthropicToResponsesResponse_CacheTokensUseOpenAIInputSemantics(t *testing.T) {
|
||||||
|
resp := &AnthropicResponse{
|
||||||
|
ID: "msg_cache",
|
||||||
|
Model: "claude-sonnet-4-5-20250929",
|
||||||
|
Content: []AnthropicContentBlock{
|
||||||
|
{Type: "text", Text: "ok"},
|
||||||
|
},
|
||||||
|
StopReason: "end_turn",
|
||||||
|
Usage: AnthropicUsage{
|
||||||
|
InputTokens: 3318,
|
||||||
|
OutputTokens: 123,
|
||||||
|
CacheReadInputTokens: 50688,
|
||||||
|
CacheCreationInputTokens: 200,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := AnthropicToResponsesResponse(resp)
|
||||||
|
require.NotNil(t, out.Usage)
|
||||||
|
// 3318 (uncached) + 50688 (read) + 200 (creation) = 54206
|
||||||
|
assert.Equal(t, 54206, out.Usage.InputTokens)
|
||||||
|
assert.Equal(t, 123, out.Usage.OutputTokens)
|
||||||
|
assert.Equal(t, 54329, out.Usage.TotalTokens)
|
||||||
|
require.NotNil(t, out.Usage.InputTokensDetails)
|
||||||
|
assert.Equal(t, 50688, out.Usage.InputTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicToResponsesResponse_NoCacheTokens(t *testing.T) {
|
||||||
|
resp := &AnthropicResponse{
|
||||||
|
ID: "msg_nocache",
|
||||||
|
Model: "claude-sonnet-4-5-20250929",
|
||||||
|
Content: []AnthropicContentBlock{
|
||||||
|
{Type: "text", Text: "ok"},
|
||||||
|
},
|
||||||
|
StopReason: "end_turn",
|
||||||
|
Usage: AnthropicUsage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := AnthropicToResponsesResponse(resp)
|
||||||
|
require.NotNil(t, out.Usage)
|
||||||
|
assert.Equal(t, 100, out.Usage.InputTokens)
|
||||||
|
assert.Equal(t, 50, out.Usage.OutputTokens)
|
||||||
|
assert.Equal(t, 150, out.Usage.TotalTokens)
|
||||||
|
assert.Nil(t, out.Usage.InputTokensDetails)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicEventToResponses_CacheTokensRoundTripFromMessageStart(t *testing.T) {
|
||||||
|
state := NewAnthropicEventToResponsesState()
|
||||||
|
|
||||||
|
// message_start carries cache fields on the initial Usage object.
|
||||||
|
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
|
||||||
|
Type: "message_start",
|
||||||
|
Message: &AnthropicResponse{
|
||||||
|
ID: "msg_stream_cache",
|
||||||
|
Model: "claude-sonnet-4-5-20250929",
|
||||||
|
Usage: AnthropicUsage{
|
||||||
|
InputTokens: 12,
|
||||||
|
CacheReadInputTokens: 9,
|
||||||
|
CacheCreationInputTokens: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
|
||||||
|
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
|
||||||
|
Type: "message_delta",
|
||||||
|
Usage: &AnthropicUsage{
|
||||||
|
OutputTokens: 7,
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
|
||||||
|
events := AnthropicEventToResponsesEvents(&AnthropicStreamEvent{Type: "message_stop"}, state)
|
||||||
|
|
||||||
|
// The terminal response.completed event must include OpenAI-semantic usage.
|
||||||
|
var completed *ResponsesStreamEvent
|
||||||
|
for i := range events {
|
||||||
|
if events[i].Type == "response.completed" {
|
||||||
|
completed = &events[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, completed, "response.completed event must be emitted")
|
||||||
|
require.NotNil(t, completed.Response)
|
||||||
|
require.NotNil(t, completed.Response.Usage)
|
||||||
|
// 12 (uncached) + 9 (read) + 3 (creation) = 24
|
||||||
|
assert.Equal(t, 24, completed.Response.Usage.InputTokens)
|
||||||
|
assert.Equal(t, 7, completed.Response.Usage.OutputTokens)
|
||||||
|
assert.Equal(t, 31, completed.Response.Usage.TotalTokens)
|
||||||
|
require.NotNil(t, completed.Response.Usage.InputTokensDetails)
|
||||||
|
assert.Equal(t, 9, completed.Response.Usage.InputTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicEventToResponses_CacheTokensFromMessageDelta(t *testing.T) {
|
||||||
|
state := NewAnthropicEventToResponsesState()
|
||||||
|
|
||||||
|
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
|
||||||
|
Type: "message_start",
|
||||||
|
Message: &AnthropicResponse{
|
||||||
|
ID: "msg_delta_cache",
|
||||||
|
Model: "claude-sonnet-4-5-20250929",
|
||||||
|
Usage: AnthropicUsage{InputTokens: 20},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
|
||||||
|
// Some upstreams only emit cache fields on the final message_delta.
|
||||||
|
AnthropicEventToResponsesEvents(&AnthropicStreamEvent{
|
||||||
|
Type: "message_delta",
|
||||||
|
Usage: &AnthropicUsage{
|
||||||
|
OutputTokens: 8,
|
||||||
|
CacheReadInputTokens: 11,
|
||||||
|
CacheCreationInputTokens: 4,
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
|
||||||
|
events := AnthropicEventToResponsesEvents(&AnthropicStreamEvent{Type: "message_stop"}, state)
|
||||||
|
|
||||||
|
var completed *ResponsesStreamEvent
|
||||||
|
for i := range events {
|
||||||
|
if events[i].Type == "response.completed" {
|
||||||
|
completed = &events[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, completed)
|
||||||
|
require.NotNil(t, completed.Response.Usage)
|
||||||
|
// 20 (uncached) + 11 (read) + 4 (creation) = 35
|
||||||
|
assert.Equal(t, 35, completed.Response.Usage.InputTokens)
|
||||||
|
assert.Equal(t, 8, completed.Response.Usage.OutputTokens)
|
||||||
|
require.NotNil(t, completed.Response.Usage.InputTokensDetails)
|
||||||
|
assert.Equal(t, 11, completed.Response.Usage.InputTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
|||||||
@ -95,10 +95,16 @@ func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Usage
|
// Usage
|
||||||
|
// Anthropic's input_tokens excludes cache_read/cache_creation, while OpenAI
|
||||||
|
// Responses' input_tokens is the total including cached tokens. Add them back
|
||||||
|
// when converting so downstream consumers see OpenAI semantics.
|
||||||
|
totalInputTokens := resp.Usage.InputTokens +
|
||||||
|
resp.Usage.CacheReadInputTokens +
|
||||||
|
resp.Usage.CacheCreationInputTokens
|
||||||
out.Usage = &ResponsesUsage{
|
out.Usage = &ResponsesUsage{
|
||||||
InputTokens: resp.Usage.InputTokens,
|
InputTokens: totalInputTokens,
|
||||||
OutputTokens: resp.Usage.OutputTokens,
|
OutputTokens: resp.Usage.OutputTokens,
|
||||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
TotalTokens: totalInputTokens + resp.Usage.OutputTokens,
|
||||||
}
|
}
|
||||||
if resp.Usage.CacheReadInputTokens > 0 {
|
if resp.Usage.CacheReadInputTokens > 0 {
|
||||||
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||||
@ -150,10 +156,13 @@ type AnthropicEventToResponsesState struct {
|
|||||||
CurrentCallID string
|
CurrentCallID string
|
||||||
CurrentName string
|
CurrentName string
|
||||||
|
|
||||||
// Usage from message_delta
|
// Usage from message_start / message_delta. InputTokens here follows
|
||||||
InputTokens int
|
// Anthropic semantics (excludes cached tokens); they are added back when
|
||||||
OutputTokens int
|
// emitting the OpenAI Responses usage.
|
||||||
CacheReadInputTokens int
|
InputTokens int
|
||||||
|
OutputTokens int
|
||||||
|
CacheReadInputTokens int
|
||||||
|
CacheCreationInputTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAnthropicEventToResponsesState returns an initialised stream state.
|
// NewAnthropicEventToResponsesState returns an initialised stream state.
|
||||||
@ -225,6 +234,12 @@ func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEven
|
|||||||
if evt.Message.Usage.InputTokens > 0 {
|
if evt.Message.Usage.InputTokens > 0 {
|
||||||
state.InputTokens = evt.Message.Usage.InputTokens
|
state.InputTokens = evt.Message.Usage.InputTokens
|
||||||
}
|
}
|
||||||
|
if evt.Message.Usage.CacheReadInputTokens > 0 {
|
||||||
|
state.CacheReadInputTokens = evt.Message.Usage.CacheReadInputTokens
|
||||||
|
}
|
||||||
|
if evt.Message.Usage.CacheCreationInputTokens > 0 {
|
||||||
|
state.CacheCreationInputTokens = evt.Message.Usage.CacheCreationInputTokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.CreatedSent {
|
if state.CreatedSent {
|
||||||
@ -392,9 +407,15 @@ func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEven
|
|||||||
// Update usage
|
// Update usage
|
||||||
if evt.Usage != nil {
|
if evt.Usage != nil {
|
||||||
state.OutputTokens = evt.Usage.OutputTokens
|
state.OutputTokens = evt.Usage.OutputTokens
|
||||||
|
if evt.Usage.InputTokens > 0 {
|
||||||
|
state.InputTokens = evt.Usage.InputTokens
|
||||||
|
}
|
||||||
if evt.Usage.CacheReadInputTokens > 0 {
|
if evt.Usage.CacheReadInputTokens > 0 {
|
||||||
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
|
state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
|
||||||
}
|
}
|
||||||
|
if evt.Usage.CacheCreationInputTokens > 0 {
|
||||||
|
state.CacheCreationInputTokens = evt.Usage.CacheCreationInputTokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -472,10 +493,13 @@ func makeResponsesCompletedEvent(
|
|||||||
seq := state.SequenceNumber
|
seq := state.SequenceNumber
|
||||||
state.SequenceNumber++
|
state.SequenceNumber++
|
||||||
|
|
||||||
|
// Anthropic's input_tokens excludes cache_read/cache_creation; add them
|
||||||
|
// back to match OpenAI Responses semantics where input_tokens is the total.
|
||||||
|
totalInputTokens := state.InputTokens + state.CacheReadInputTokens + state.CacheCreationInputTokens
|
||||||
usage := &ResponsesUsage{
|
usage := &ResponsesUsage{
|
||||||
InputTokens: state.InputTokens,
|
InputTokens: totalInputTokens,
|
||||||
OutputTokens: state.OutputTokens,
|
OutputTokens: state.OutputTokens,
|
||||||
TotalTokens: state.InputTokens + state.OutputTokens,
|
TotalTokens: totalInputTokens + state.OutputTokens,
|
||||||
}
|
}
|
||||||
if state.CacheReadInputTokens > 0 {
|
if state.CacheReadInputTokens > 0 {
|
||||||
usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
usage.InputTokensDetails = &ResponsesInputTokensDetails{
|
||||||
|
|||||||
@ -663,6 +663,115 @@ func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
|
|||||||
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_ReasoningTokens(t *testing.T) {
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_reasoning",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: "ping"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 24,
|
||||||
|
OutputTokens: 33,
|
||||||
|
TotalTokens: 57,
|
||||||
|
OutputTokensDetails: &ResponsesOutputTokensDetails{
|
||||||
|
ReasoningTokens: 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-5.5")
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
assert.Equal(t, 33, chat.Usage.CompletionTokens)
|
||||||
|
require.NotNil(t, chat.Usage.CompletionTokensDetails)
|
||||||
|
assert.Equal(t, 32, chat.Usage.CompletionTokensDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_AllTokenDetailsPassThrough(t *testing.T) {
|
||||||
|
// Covers the full OpenAI CompletionUsage detail field set so future audio
|
||||||
|
// and prediction-outputs responses propagate without further changes.
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_full_details",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: "x"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||||
|
CachedTokens: 60,
|
||||||
|
AudioTokens: 4,
|
||||||
|
},
|
||||||
|
OutputTokensDetails: &ResponsesOutputTokensDetails{
|
||||||
|
ReasoningTokens: 30,
|
||||||
|
AudioTokens: 2,
|
||||||
|
AcceptedPredictionTokens: 10,
|
||||||
|
RejectedPredictionTokens: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-5.5")
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
require.NotNil(t, chat.Usage.PromptTokensDetails)
|
||||||
|
assert.Equal(t, 60, chat.Usage.PromptTokensDetails.CachedTokens)
|
||||||
|
assert.Equal(t, 4, chat.Usage.PromptTokensDetails.AudioTokens)
|
||||||
|
|
||||||
|
require.NotNil(t, chat.Usage.CompletionTokensDetails)
|
||||||
|
assert.Equal(t, 30, chat.Usage.CompletionTokensDetails.ReasoningTokens)
|
||||||
|
assert.Equal(t, 2, chat.Usage.CompletionTokensDetails.AudioTokens)
|
||||||
|
assert.Equal(t, 10, chat.Usage.CompletionTokensDetails.AcceptedPredictionTokens)
|
||||||
|
assert.Equal(t, 3, chat.Usage.CompletionTokensDetails.RejectedPredictionTokens)
|
||||||
|
|
||||||
|
raw, err := json.Marshal(chat.Usage)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(raw), `"prompt_tokens_details"`)
|
||||||
|
assert.Contains(t, string(raw), `"completion_tokens_details"`)
|
||||||
|
assert.Contains(t, string(raw), `"reasoning_tokens":30`)
|
||||||
|
assert.Contains(t, string(raw), `"accepted_prediction_tokens":10`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesToChatCompletions_NoReasoningTokensWhenZero(t *testing.T) {
|
||||||
|
// Non-reasoning models do not return reasoning_tokens. The mapping must
|
||||||
|
// omit completion_tokens_details entirely rather than emitting a zero-valued
|
||||||
|
// field, so non-reasoning responses stay clean.
|
||||||
|
resp := &ResponsesResponse{
|
||||||
|
ID: "resp_no_reasoning",
|
||||||
|
Status: "completed",
|
||||||
|
Output: []ResponsesOutput{
|
||||||
|
{
|
||||||
|
Type: "message",
|
||||||
|
Content: []ResponsesContentPart{{Type: "output_text", Text: "hi"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
OutputTokensDetails: &ResponsesOutputTokensDetails{
|
||||||
|
ReasoningTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chat := ResponsesToChatCompletions(resp, "gpt-4o")
|
||||||
|
require.NotNil(t, chat.Usage)
|
||||||
|
assert.Nil(t, chat.Usage.CompletionTokensDetails)
|
||||||
|
|
||||||
|
raw, err := json.Marshal(chat.Usage)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotContains(t, string(raw), "completion_tokens_details")
|
||||||
|
assert.NotContains(t, string(raw), "reasoning_tokens")
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
|
||||||
resp := &ResponsesResponse{
|
resp := &ResponsesResponse{
|
||||||
ID: "resp_ws",
|
ID: "resp_ws",
|
||||||
@ -825,6 +934,32 @@ func TestResponsesEventToChatChunks_Completed(t *testing.T) {
|
|||||||
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_CompletedWithReasoningTokens(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-5.5"
|
||||||
|
state.IncludeUsage = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.completed",
|
||||||
|
Response: &ResponsesResponse{
|
||||||
|
Status: "completed",
|
||||||
|
Usage: &ResponsesUsage{
|
||||||
|
InputTokens: 24,
|
||||||
|
OutputTokens: 33,
|
||||||
|
TotalTokens: 57,
|
||||||
|
OutputTokensDetails: &ResponsesOutputTokensDetails{
|
||||||
|
ReasoningTokens: 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 2)
|
||||||
|
|
||||||
|
require.NotNil(t, chunks[1].Usage)
|
||||||
|
require.NotNil(t, chunks[1].Usage.CompletionTokensDetails)
|
||||||
|
assert.Equal(t, 32, chunks[1].Usage.CompletionTokensDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||||
state := NewResponsesEventToChatState()
|
state := NewResponsesEventToChatState()
|
||||||
state.Model = "gpt-4o"
|
state.Model = "gpt-4o"
|
||||||
|
|||||||
@ -81,19 +81,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
FinishReason: finishReason,
|
FinishReason: finishReason,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
if resp.Usage != nil {
|
out.Usage = chatUsageFromResponsesUsage(resp.Usage)
|
||||||
usage := &ChatUsage{
|
|
||||||
PromptTokens: resp.Usage.InputTokens,
|
|
||||||
CompletionTokens: resp.Usage.OutputTokens,
|
|
||||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
|
||||||
}
|
|
||||||
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
|
|
||||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
|
||||||
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out.Usage = usage
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
@ -341,14 +329,48 @@ func chatUsageFromResponsesUsage(u *ResponsesUsage) *ChatUsage {
|
|||||||
CompletionTokens: u.OutputTokens,
|
CompletionTokens: u.OutputTokens,
|
||||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||||
}
|
}
|
||||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
usage.PromptTokensDetails = promptDetailsFromResponses(u.InputTokensDetails)
|
||||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
usage.CompletionTokensDetails = completionDetailsFromResponses(u.OutputTokensDetails)
|
||||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return usage
|
return usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// promptDetailsFromResponses maps Responses-API input_tokens_details into a
|
||||||
|
// Chat-Completions prompt_tokens_details. Returns nil when nothing would be
|
||||||
|
// emitted, so upstreams that do not break down prompt usage stay clean.
|
||||||
|
func promptDetailsFromResponses(src *ResponsesInputTokensDetails) *ChatTokenDetails {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if src.CachedTokens == 0 && src.AudioTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ChatTokenDetails{
|
||||||
|
CachedTokens: src.CachedTokens,
|
||||||
|
AudioTokens: src.AudioTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// completionDetailsFromResponses maps Responses-API output_tokens_details
|
||||||
|
// into a Chat-Completions completion_tokens_details. Mirrors the OpenAI
|
||||||
|
// official CompletionUsage schema: reasoning_tokens, audio_tokens, and
|
||||||
|
// the predicted-outputs accepted/rejected counts. Returns nil when nothing
|
||||||
|
// would be emitted so non-reasoning, non-audio responses stay clean.
|
||||||
|
func completionDetailsFromResponses(src *ResponsesOutputTokensDetails) *ChatTokenDetails {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if src.ReasoningTokens == 0 && src.AudioTokens == 0 &&
|
||||||
|
src.AcceptedPredictionTokens == 0 && src.RejectedPredictionTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ChatTokenDetails{
|
||||||
|
ReasoningTokens: src.ReasoningTokens,
|
||||||
|
AudioTokens: src.AudioTokens,
|
||||||
|
AcceptedPredictionTokens: src.AcceptedPredictionTokens,
|
||||||
|
RejectedPredictionTokens: src.RejectedPredictionTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||||
return ChatCompletionsChunk{
|
return ChatCompletionsChunk{
|
||||||
ID: state.ID,
|
ID: state.ID,
|
||||||
|
|||||||
@ -362,11 +362,15 @@ func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
|
|||||||
// ResponsesInputTokensDetails breaks down input token usage.
|
// ResponsesInputTokensDetails breaks down input token usage.
|
||||||
type ResponsesInputTokensDetails struct {
|
type ResponsesInputTokensDetails struct {
|
||||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||||
|
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponsesOutputTokensDetails breaks down output token usage.
|
// ResponsesOutputTokensDetails breaks down output token usage.
|
||||||
type ResponsesOutputTokensDetails struct {
|
type ResponsesOutputTokensDetails struct {
|
||||||
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||||
|
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||||
|
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
|
||||||
|
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@ -517,15 +521,27 @@ type ChatChoice struct {
|
|||||||
|
|
||||||
// ChatUsage holds token counts in Chat Completions format.
|
// ChatUsage holds token counts in Chat Completions format.
|
||||||
type ChatUsage struct {
|
type ChatUsage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
|
||||||
|
CompletionTokensDetails *ChatTokenDetails `json:"completion_tokens_details,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatTokenDetails provides a breakdown of token usage.
|
// ChatTokenDetails provides a breakdown of token usage. The same type is
|
||||||
|
// reused for both prompt_tokens_details and completion_tokens_details;
|
||||||
|
// unset fields are omitted so each side only emits the fields that apply.
|
||||||
|
//
|
||||||
|
// Field set mirrors OpenAI's official CompletionUsage schema:
|
||||||
|
// - prompt_tokens_details: cached_tokens, audio_tokens
|
||||||
|
// - completion_tokens_details: reasoning_tokens, audio_tokens,
|
||||||
|
// accepted_prediction_tokens, rejected_prediction_tokens
|
||||||
type ChatTokenDetails struct {
|
type ChatTokenDetails struct {
|
||||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||||
|
AudioTokens int `json:"audio_tokens,omitempty"`
|
||||||
|
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
|
||||||
|
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
|
||||||
|
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
|
||||||
|
|||||||
@ -446,6 +446,12 @@ var DefaultModels = []Model{
|
|||||||
DisplayName: "Claude Opus 4.7",
|
DisplayName: "Claude Opus 4.7",
|
||||||
CreatedAt: "2026-04-17T00:00:00Z",
|
CreatedAt: "2026-04-17T00:00:00Z",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4-8",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: "Claude Opus 4.8",
|
||||||
|
CreatedAt: "2026-05-29T00:00:00Z",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4-6",
|
ID: "claude-sonnet-4-6",
|
||||||
Type: "model",
|
Type: "model",
|
||||||
|
|||||||
78
backend/internal/pkg/openai/allowed_client.go
Normal file
78
backend/internal/pkg/openai/allowed_client.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// 命名预设 ID。账号侧 codex_cli_only_allowed_clients 只能引用这些预设键,
|
||||||
|
// 具体匹配规则固化在下方 registry 中,配置只能「选择启用哪些预设」、不能自定义规则,
|
||||||
|
// 以防该白名单退化为可任意放宽的后门。
|
||||||
|
const (
|
||||||
|
// AllowedClientClaudeCode 对应 Claude Code CLI 的 codex 插件。
|
||||||
|
AllowedClientClaudeCode = "claude_code"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AllowedClientEntry 描述一个被额外放行的非官方 Codex 客户端签名。
|
||||||
|
// Originator 必须精确等值匹配(归一化后)。
|
||||||
|
// UAContains 为必填字段:列表为空,或列表中存在任何空白 marker,均视为非法配置,
|
||||||
|
// 整体安全失败(return false);每一项都必须出现在 User-Agent 中。
|
||||||
|
// 这确保双因子匹配不会因缺失 UA 声明而退化为仅凭可伪造的 originator 单因子放行。
|
||||||
|
type AllowedClientEntry struct {
|
||||||
|
Originator string
|
||||||
|
UAContains []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowedClientRegistry 固化各命名预设的签名规则。
|
||||||
|
//
|
||||||
|
// Claude Code codex 插件签名来源:插件以 clientInfo.name="Claude Code" 完成 app-server
|
||||||
|
// initialize 握手,codex 据此把 originator 设为 "Claude Code",User-Agent 前缀同样为
|
||||||
|
// "Claude Code/"(两者同源)。若上游 Claude Code 插件更改 clientInfo.name,此处需同步更新。
|
||||||
|
var allowedClientRegistry = map[string]AllowedClientEntry{
|
||||||
|
AllowedClientClaudeCode: {
|
||||||
|
Originator: "Claude Code",
|
||||||
|
UAContains: []string{"Claude Code/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllowedClientMatch 判断请求头是否命中给定的额外客户端签名。
|
||||||
|
// originator 必须精确等值(归一化后);UAContains 中每一项都必须出现在 UA 中。
|
||||||
|
// UAContains 为必填:列表为空或含任何空白 marker 均视为非法配置,整体安全失败。
|
||||||
|
func IsAllowedClientMatch(userAgent, originator string, entry AllowedClientEntry) bool {
|
||||||
|
wantOriginator := normalizeCodexClientHeader(entry.Originator)
|
||||||
|
if wantOriginator == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if normalizeCodexClientHeader(originator) != wantOriginator {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 预设必须声明 UA 特征:否则将退化为仅凭可伪造的 originator 单因子匹配。
|
||||||
|
if len(entry.UAContains) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ua := normalizeCodexClientHeader(userAgent)
|
||||||
|
for _, marker := range entry.UAContains {
|
||||||
|
normalizedMarker := normalizeCodexClientHeader(marker)
|
||||||
|
if normalizedMarker == "" {
|
||||||
|
// 空白 marker 让该项失去校验能力,会让双因子退化为仅 originator
|
||||||
|
// 单因子;视为非法配置,安全失败。
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !strings.Contains(ua, normalizedMarker) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchAllowedClients 判断请求头是否命中 clientIDs 引用的任一预设签名。
|
||||||
|
// 未知预设 ID 会被忽略;空列表恒不放行(默认拒绝)。
|
||||||
|
func MatchAllowedClients(userAgent, originator string, clientIDs []string) bool {
|
||||||
|
for _, id := range clientIDs {
|
||||||
|
entry, ok := allowedClientRegistry[normalizeCodexClientHeader(id)]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if IsAllowedClientMatch(userAgent, originator, entry) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
95
backend/internal/pkg/openai/allowed_client_test.go
Normal file
95
backend/internal/pkg/openai/allowed_client_test.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
// 真实的 Claude Code codex 插件请求头:originator 与 UA 前缀同源于 clientInfo.name="Claude Code"。
|
||||||
|
const (
|
||||||
|
testClaudeCodeOriginator = "Claude Code"
|
||||||
|
testClaudeCodeUserAgent = "Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsAllowedClientMatch(t *testing.T) {
|
||||||
|
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{"Claude Code/"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ua string
|
||||||
|
originator string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "真实签名命中", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, want: true},
|
||||||
|
{name: "大小写不敏感", ua: "claude code/0.5.0 (macos)", originator: "claude code", want: true},
|
||||||
|
{name: "originator 两侧空白被裁剪", ua: testClaudeCodeUserAgent, originator: " Claude Code ", want: true},
|
||||||
|
{name: "originator 非精确(带后缀)不命中", ua: testClaudeCodeUserAgent, originator: "Claude Code Extra", want: false},
|
||||||
|
{name: "originator 为空不命中", ua: testClaudeCodeUserAgent, originator: "", want: false},
|
||||||
|
{name: "originator 是官方 codex 不命中", ua: testClaudeCodeUserAgent, originator: "codex_cli_rs", want: false},
|
||||||
|
{name: "UA 缺少 Claude Code/ 标记不命中", ua: "curl/8.0", originator: testClaudeCodeOriginator, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsAllowedClientMatch(tt.ua, tt.originator, entry); got != tt.want {
|
||||||
|
t.Fatalf("IsAllowedClientMatch(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedClientMatch_EmptyOriginatorEntryNeverMatches(t *testing.T) {
|
||||||
|
// registry 条目若没有配置 Originator,绝不放行,避免成为宽松后门。
|
||||||
|
entry := AllowedClientEntry{Originator: "", UAContains: []string{"Claude Code/"}}
|
||||||
|
if IsAllowedClientMatch(testClaudeCodeUserAgent, "", entry) {
|
||||||
|
t.Fatal("空 Originator 的条目不应匹配任何请求")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedClientMatch_EmptyUAContainsNeverMatches(t *testing.T) {
|
||||||
|
// 预设必须声明 UA 特征,否则退化为仅凭可伪造的 originator 单因子匹配,绝不放行。
|
||||||
|
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: nil}
|
||||||
|
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
|
||||||
|
t.Fatal("未声明 UA 特征的预设不应匹配,避免退化为单因子 originator 匹配")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedClientMatch_WhitespaceUAMarkerNeverMatches(t *testing.T) {
|
||||||
|
// 全空白 marker 归一化后为空,若被跳过则退化为仅 originator 单因子;
|
||||||
|
// 任何空白 marker 视为非法预设配置,必须安全失败。
|
||||||
|
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{" "}}
|
||||||
|
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
|
||||||
|
t.Fatal("UAContains 含全空白 marker 不应匹配,避免退化为单因子 originator 匹配")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedClientMatch_MixedEmptyUAMarkerNeverMatches(t *testing.T) {
|
||||||
|
// 即便 UAContains 含一个真实 marker,只要其中混入任何空白 marker 也视为非法配置;
|
||||||
|
// 防止维护者只为对齐凑数而插入空字符串。
|
||||||
|
entry := AllowedClientEntry{Originator: "Claude Code", UAContains: []string{"", "Claude Code/"}}
|
||||||
|
if IsAllowedClientMatch(testClaudeCodeUserAgent, testClaudeCodeOriginator, entry) {
|
||||||
|
t.Fatal("UAContains 混入空白 marker 不应匹配")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAllowedClients(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ua string
|
||||||
|
originator string
|
||||||
|
clientIDs []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "claude_code 预设命中真实签名", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{AllowedClientClaudeCode}, want: true},
|
||||||
|
{name: "claude_code 预设 + 伪造 originator 不命中", ua: testClaudeCodeUserAgent, originator: "my_client", clientIDs: []string{AllowedClientClaudeCode}, want: false},
|
||||||
|
{name: "空列表不放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: nil, want: false},
|
||||||
|
{name: "未知预设 ID 不放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{"unknown_client"}, want: false},
|
||||||
|
{name: "ID 大小写/空白容错", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{" Claude_Code "}, want: true},
|
||||||
|
{name: "多预设任一命中即放行", ua: testClaudeCodeUserAgent, originator: testClaudeCodeOriginator, clientIDs: []string{"unknown_client", AllowedClientClaudeCode}, want: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := MatchAllowedClients(tt.ua, tt.originator, tt.clientIDs); got != tt.want {
|
||||||
|
t.Fatalf("MatchAllowedClients(%q, %q, %v) = %v, want %v", tt.ua, tt.originator, tt.clientIDs, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -548,12 +548,17 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
|
|||||||
"openai_ws_force_http",
|
"openai_ws_force_http",
|
||||||
"openai_responses_mode",
|
"openai_responses_mode",
|
||||||
"openai_responses_supported",
|
"openai_responses_supported",
|
||||||
// model_rate_limits 必须进入调度快照:SetModelRateLimit 写入的模型级冷却
|
"codex_5h_used_percent",
|
||||||
// 时间戳(accounts.extra.model_rate_limits.<modelKey>.rate_limit_reset_at)
|
"codex_7d_used_percent",
|
||||||
// 是 isAccountSchedulableForModelSelection/IsSchedulableForModelWithContext
|
"codex_5h_reset_at",
|
||||||
// 过滤候选账号的唯一依据。缺失会导致已限流账号被反复选中,触发 failover 切号环。
|
"codex_7d_reset_at",
|
||||||
// 与 service.modelRateLimitsKey 常量保持字面量一致。
|
"codex_5h_reset_after_seconds",
|
||||||
"model_rate_limits",
|
"codex_7d_reset_after_seconds",
|
||||||
|
"codex_usage_updated_at",
|
||||||
|
"auto_pause_5h_threshold",
|
||||||
|
"auto_pause_7d_threshold",
|
||||||
|
"auto_pause_5h_disabled",
|
||||||
|
"auto_pause_7d_disabled",
|
||||||
}
|
}
|
||||||
filtered := make(map[string]any)
|
filtered := make(map[string]any)
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
|||||||
@ -100,3 +100,36 @@ func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
|
|||||||
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
|
require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
|
||||||
require.Nil(t, got.Groups)
|
require.Nil(t, got.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildSchedulerMetadataAccount_KeepsQuotaAutoPauseFields(t *testing.T) {
|
||||||
|
account := service.Account{
|
||||||
|
ID: 88,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 12.34,
|
||||||
|
"codex_7d_used_percent": 56.78,
|
||||||
|
"codex_5h_reset_at": "2026-05-29T10:00:00Z",
|
||||||
|
"codex_7d_reset_at": "2026-06-01T10:00:00Z",
|
||||||
|
"codex_5h_reset_after_seconds": 300,
|
||||||
|
"codex_7d_reset_after_seconds": 600,
|
||||||
|
"codex_usage_updated_at": "2026-05-29T09:00:00Z",
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
"auto_pause_7d_threshold": 0.96,
|
||||||
|
"auto_pause_5h_disabled": true,
|
||||||
|
"auto_pause_7d_disabled": false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := buildSchedulerMetadataAccount(account)
|
||||||
|
|
||||||
|
require.Equal(t, 12.34, got.Extra["codex_5h_used_percent"])
|
||||||
|
require.Equal(t, 56.78, got.Extra["codex_7d_used_percent"])
|
||||||
|
require.Equal(t, "2026-05-29T10:00:00Z", got.Extra["codex_5h_reset_at"])
|
||||||
|
require.Equal(t, "2026-06-01T10:00:00Z", got.Extra["codex_7d_reset_at"])
|
||||||
|
require.Equal(t, 300, got.Extra["codex_5h_reset_after_seconds"])
|
||||||
|
require.Equal(t, 600, got.Extra["codex_7d_reset_after_seconds"])
|
||||||
|
require.Equal(t, "2026-05-29T09:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||||
|
require.Equal(t, 0.95, got.Extra["auto_pause_5h_threshold"])
|
||||||
|
require.Equal(t, 0.96, got.Extra["auto_pause_7d_threshold"])
|
||||||
|
require.Equal(t, true, got.Extra["auto_pause_5h_disabled"])
|
||||||
|
require.Equal(t, false, got.Extra["auto_pause_7d_disabled"])
|
||||||
|
}
|
||||||
|
|||||||
@ -843,6 +843,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"payment_visible_method_wxpay_enabled": false,
|
"payment_visible_method_wxpay_enabled": false,
|
||||||
"openai_advanced_scheduler_enabled": true,
|
"openai_advanced_scheduler_enabled": true,
|
||||||
"openai_codex_user_agent": "",
|
"openai_codex_user_agent": "",
|
||||||
|
"openai_allow_claude_code_codex_plugin": false,
|
||||||
"openai_fast_policy_settings": {
|
"openai_fast_policy_settings": {
|
||||||
"rules": []
|
"rules": []
|
||||||
},
|
},
|
||||||
@ -1079,6 +1080,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"payment_visible_method_wxpay_enabled": false,
|
"payment_visible_method_wxpay_enabled": false,
|
||||||
"openai_advanced_scheduler_enabled": false,
|
"openai_advanced_scheduler_enabled": false,
|
||||||
"openai_codex_user_agent": "",
|
"openai_codex_user_agent": "",
|
||||||
|
"openai_allow_claude_code_codex_plugin": false,
|
||||||
"openai_fast_policy_settings": {
|
"openai_fast_policy_settings": {
|
||||||
"rules": []
|
"rules": []
|
||||||
},
|
},
|
||||||
|
|||||||
@ -11,6 +11,8 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const clientRequestIDHeader = "X-Client-Request-ID"
|
||||||
|
|
||||||
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||||
//
|
//
|
||||||
// This is used by the Ops monitoring module for end-to-end request correlation.
|
// This is used by the Ops monitoring module for end-to-end request correlation.
|
||||||
@ -21,12 +23,14 @@ func ClientRequestID() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
|
if v, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(v) != "" {
|
||||||
|
c.Header(clientRequestIDHeader, strings.TrimSpace(v))
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
|
c.Header(clientRequestIDHeader, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
||||||
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
||||||
ctx = logger.IntoContext(ctx, requestLogger)
|
ctx = logger.IntoContext(ctx, requestLogger)
|
||||||
|
|||||||
50
backend/internal/server/middleware/client_request_id_test.go
Normal file
50
backend/internal/server/middleware/client_request_id_test.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientRequestIDGeneratesAndExposesID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ClientRequestID())
|
||||||
|
router.GET("/", func(c *gin.Context) {
|
||||||
|
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
|
c.String(http.StatusOK, value)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.NotEmpty(t, w.Body.String())
|
||||||
|
require.Equal(t, w.Body.String(), w.Header().Get(clientRequestIDHeader))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientRequestIDPreservesExistingContextID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ClientRequestID())
|
||||||
|
router.GET("/", func(c *gin.Context) {
|
||||||
|
value, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
|
c.String(http.StatusOK, value)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "existing-client-request-id"))
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
require.Equal(t, "existing-client-request-id", w.Body.String())
|
||||||
|
require.Equal(t, "existing-client-request-id", w.Header().Get(clientRequestIDHeader))
|
||||||
|
}
|
||||||
@ -66,6 +66,15 @@ type Account struct {
|
|||||||
modelMappingCacheRawSig uint64
|
modelMappingCacheRawSig uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAIEndpointCapability string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenAIEndpointCapabilityChatCompletions OpenAIEndpointCapability = "chat_completions"
|
||||||
|
OpenAIEndpointCapabilityEmbeddings OpenAIEndpointCapability = "embeddings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openAIEndpointCapabilitiesCredentialKey = "openai_capabilities"
|
||||||
|
|
||||||
type TempUnschedulableRule struct {
|
type TempUnschedulableRule struct {
|
||||||
ErrorCode int `json:"error_code"`
|
ErrorCode int `json:"error_code"`
|
||||||
Keywords []string `json:"keywords"`
|
Keywords []string `json:"keywords"`
|
||||||
@ -1153,6 +1162,80 @@ func (a *Account) GetOpenAISessionID() string {
|
|||||||
return strings.TrimSpace(a.GetExtraString("openai_session_id"))
|
return strings.TrimSpace(a.GetExtraString("openai_session_id"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) SupportsOpenAIEndpointCapability(capability OpenAIEndpointCapability) bool {
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if capability == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !a.IsOpenAI() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch capability {
|
||||||
|
case OpenAIEndpointCapabilityChatCompletions:
|
||||||
|
case OpenAIEndpointCapabilityEmbeddings:
|
||||||
|
if a.Type != AccountTypeAPIKey {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
configured, found := a.openAIEndpointCapabilitySet()
|
||||||
|
if !found {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return configured[string(capability)]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) openAIEndpointCapabilitySet() (map[string]bool, bool) {
|
||||||
|
if a == nil || a.Credentials == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
raw, found := a.Credentials[openAIEndpointCapabilitiesCredentialKey]
|
||||||
|
if !found || raw == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]bool)
|
||||||
|
add := func(value string) {
|
||||||
|
value = strings.ToLower(strings.TrimSpace(value))
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result[value] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch capabilities := raw.(type) {
|
||||||
|
case []any:
|
||||||
|
for _, item := range capabilities {
|
||||||
|
if value, ok := item.(string); ok {
|
||||||
|
add(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, value := range capabilities {
|
||||||
|
add(value)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
for key, value := range capabilities {
|
||||||
|
enabled, ok := value.(bool)
|
||||||
|
if ok && enabled {
|
||||||
|
add(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case map[string]bool:
|
||||||
|
for key, enabled := range capabilities {
|
||||||
|
if enabled {
|
||||||
|
add(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, true
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
|
func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
|
||||||
if !a.IsOpenAI() {
|
if !a.IsOpenAI() {
|
||||||
return false
|
return false
|
||||||
@ -1473,6 +1556,38 @@ func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
|||||||
return ok && enabled
|
return ok && enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCodexCLIOnlyAllowedClients 返回 codex_cli_only 之上额外放行的命名客户端预设 ID 列表。
|
||||||
|
// 仅 OpenAI OAuth 账号生效;缺失或类型不符时返回空。预设 ID 的具体匹配规则由
|
||||||
|
// openai 包的 registry 固化,配置只能引用预设键、不能自定义规则。
|
||||||
|
func (a *Account) GetCodexCLIOnlyAllowedClients() []string {
|
||||||
|
if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw, ok := a.Extra["codex_cli_only_allowed_clients"]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case []string:
|
||||||
|
result := make([]string, 0, len(v))
|
||||||
|
for _, s := range v {
|
||||||
|
if strings.TrimSpace(s) != "" {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
case []any:
|
||||||
|
result := make([]string, 0, len(v))
|
||||||
|
for _, item := range v {
|
||||||
|
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// WindowCostSchedulability 窗口费用调度状态
|
// WindowCostSchedulability 窗口费用调度状态
|
||||||
type WindowCostSchedulability int
|
type WindowCostSchedulability int
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,68 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccount_GetCodexCLIOnlyAllowedClients(t *testing.T) {
|
||||||
|
t.Run("OAuth 账号读取 []any 字符串列表", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
|
||||||
|
}
|
||||||
|
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OAuth 账号读取 []string 列表", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []string{"claude_code"}},
|
||||||
|
}
|
||||||
|
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("[]string 跳过空白元素", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []string{"claude_code", "", " "}},
|
||||||
|
}
|
||||||
|
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("跳过非字符串与空白元素", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code", 123, "", " "}},
|
||||||
|
}
|
||||||
|
require.Equal(t, []string{"claude_code"}, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非 OAuth 账号返回空", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
|
||||||
|
}
|
||||||
|
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Extra 为空返回空", func(t *testing.T) {
|
||||||
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||||
|
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("字段缺失返回空", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
}
|
||||||
|
require.Empty(t, account.GetCodexCLIOnlyAllowedClients())
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -4209,6 +4209,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
// 构建上游请求 URL
|
// 构建上游请求 URL
|
||||||
upstreamURL := baseURL + "/v1/messages"
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
|
|
||||||
|
// 能力维度 sanitize:Anthropic-compatible 上游透传路径也需要保证 body↔beta header
|
||||||
|
// 对称。客户端 anthropic-beta header 不含 context-management-2025-06-27 但 body 带
|
||||||
|
// context_management 时 strip,与 Anthropic 直连 / Bedrock / Vertex 路径保持一致。
|
||||||
|
clientBeta := c.GetHeader("anthropic-beta")
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
|
||||||
|
body = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
// 创建请求
|
// 创建请求
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -4224,7 +4232,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||||
req.Header.Set("anthropic-version", v)
|
req.Header.Set("anthropic-version", v)
|
||||||
}
|
}
|
||||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
if v := clientBeta; v != "" {
|
||||||
req.Header.Set("anthropic-beta", v)
|
req.Header.Set("anthropic-beta", v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -88,6 +88,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
|||||||
accountMapping: nil,
|
accountMapping: nil,
|
||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-5",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "默认映射透传 - claude-opus-4-8",
|
||||||
|
requestedModel: "claude-opus-4-8",
|
||||||
|
accountMapping: nil,
|
||||||
|
expected: "claude-opus-4-8",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "默认映射透传 - claude-opus-4-7",
|
||||||
|
requestedModel: "claude-opus-4-7",
|
||||||
|
accountMapping: nil,
|
||||||
|
expected: "claude-opus-4-7",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "默认映射透传 - claude-opus-4-6-thinking",
|
name: "默认映射透传 - claude-opus-4-6-thinking",
|
||||||
requestedModel: "claude-opus-4-6-thinking",
|
requestedModel: "claude-opus-4-6-thinking",
|
||||||
@ -210,6 +222,7 @@ func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
|
|||||||
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
||||||
|
|
||||||
// 可映射(有明确前缀映射)
|
// 可映射(有明确前缀映射)
|
||||||
|
{"可映射 - claude-opus-4-8", "claude-opus-4-8", true},
|
||||||
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
|
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
|
||||||
|
|
||||||
// 前缀透传(claude 和 gemini 前缀)
|
// 前缀透传(claude 和 gemini 前缀)
|
||||||
|
|||||||
@ -174,6 +174,7 @@ func TestIsBedrockClaude45OrNewer(t *testing.T) {
|
|||||||
expect bool
|
expect bool
|
||||||
}{
|
}{
|
||||||
{"us.anthropic.claude-opus-4-6-v1", true},
|
{"us.anthropic.claude-opus-4-6-v1", true},
|
||||||
|
{"us.anthropic.claude-opus-4-8-v1", true},
|
||||||
{"us.anthropic.claude-sonnet-4-6", true},
|
{"us.anthropic.claude-sonnet-4-6", true},
|
||||||
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||||
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||||
@ -511,6 +512,20 @@ func TestResolveBedrockModelID(t *testing.T) {
|
|||||||
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("default opus 4.8 mapping uses regional Bedrock model id", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "eu-west-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-8")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "eu.anthropic.claude-opus-4-8-v1", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformAnthropic,
|
Platform: PlatformAnthropic,
|
||||||
@ -714,6 +729,7 @@ func TestIsBedrockOpus47OrNewer(t *testing.T) {
|
|||||||
modelID string
|
modelID string
|
||||||
expect bool
|
expect bool
|
||||||
}{
|
}{
|
||||||
|
{"us.anthropic.claude-opus-4-8-v1", true},
|
||||||
{"us.anthropic.claude-opus-4-7-v1", true},
|
{"us.anthropic.claude-opus-4-7-v1", true},
|
||||||
{"us.anthropic.claude-opus-4-6-v1", false},
|
{"us.anthropic.claude-opus-4-6-v1", false},
|
||||||
{"us.anthropic.claude-opus-4-5-20251101-v1:0", false},
|
{"us.anthropic.claude-opus-4-5-20251101-v1:0", false},
|
||||||
@ -886,10 +902,12 @@ func TestIsBedrockOpus47OrNewer_EdgeCases(t *testing.T) {
|
|||||||
modelID string
|
modelID string
|
||||||
expect bool
|
expect bool
|
||||||
}{
|
}{
|
||||||
|
{"anthropic.claude-opus-4-8-v1", true},
|
||||||
{"anthropic.claude-opus-4-7-v1", true},
|
{"anthropic.claude-opus-4-7-v1", true},
|
||||||
{"us.anthropic.claude-opus-4-7-20270101-v1:0", true},
|
{"us.anthropic.claude-opus-4-7-20270101-v1:0", true},
|
||||||
{"", false},
|
{"", false},
|
||||||
// Forward() passes parsed.Model (standard names), not Bedrock IDs
|
// Forward() passes parsed.Model (standard names), not Bedrock IDs
|
||||||
|
{"claude-opus-4-8", true},
|
||||||
{"claude-opus-4-7", true},
|
{"claude-opus-4-7", true},
|
||||||
{"claude-opus-4-6", false},
|
{"claude-opus-4-6", false},
|
||||||
{"claude-sonnet-4-7", false},
|
{"claude-sonnet-4-7", false},
|
||||||
|
|||||||
@ -432,6 +432,9 @@ const (
|
|||||||
// 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
|
// 当客户端 UA 被识别为浏览器(Chrome/Firefox/Safari/Edge 等)时,转发给 OpenAI 上游前会替换为此值,
|
||||||
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
|
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
|
||||||
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
|
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
|
||||||
|
// SettingKeyOpenAIAllowClaudeCodeCodexPlugin 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认 false)。
|
||||||
|
// 仅在账号 codex_cli_only 开启时生效;开启后无需逐账号配置 codex_cli_only_allowed_clients。
|
||||||
|
SettingKeyOpenAIAllowClaudeCodeCodexPlugin = "openai_allow_claude_code_codex_plugin"
|
||||||
|
|
||||||
// 余额不足提醒
|
// 余额不足提醒
|
||||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||||
|
|||||||
@ -476,6 +476,66 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFie
|
|||||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFiltersGenerationFields(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"tool","input_schema":{"type":"object"}}],"temperature":0.7,"top_p":0.9,"top_k":40,"stream":true,"stop_sequences":["END"],"max_tokens":1024,"thinking":{"type":"enabled","budget_tokens":5000}}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: "claude-sonnet-4-20250514",
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":42}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 302,
|
||||||
|
Name: "count-token-filter-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sentBody := upstream.lastBody
|
||||||
|
require.False(t, gjson.GetBytes(sentBody, "temperature").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(sentBody, "top_p").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(sentBody, "top_k").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(sentBody, "stream").Exists())
|
||||||
|
require.False(t, gjson.GetBytes(sentBody, "stop_sequences").Exists())
|
||||||
|
require.Equal(t, "claude-sonnet-4-20250514", gjson.GetBytes(sentBody, "model").String())
|
||||||
|
require.Equal(t, "sys", gjson.GetBytes(sentBody, "system.0.text").String())
|
||||||
|
require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String())
|
||||||
|
require.Equal(t, "tool", gjson.GetBytes(sentBody, "tools.0.name").String())
|
||||||
|
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int())
|
||||||
|
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||||
// 确保空模型名不会触发映射逻辑
|
// 确保空模型名不会触发映射逻辑
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||||
|
|||||||
@ -66,3 +66,67 @@ func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Vertex 路径回归保护:同样需要
|
||||||
|
// body↔beta header 能力维度对称。客户端 header 不带 context-management beta
|
||||||
|
// 但 body 带 context_management 字段 → Vertex builder 必须 strip 字段,与 Anthropic
|
||||||
|
// 直连 / Bedrock 路径保持一致。
|
||||||
|
func TestGatewayService_BuildAnthropicVertexServiceAccount_StripsContextManagementWhenBetaMissing(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
// 客户端 header 只带 interleaved-thinking,不带 context-management-2025-06-27
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 302, Platform: PlatformAnthropic, Type: AccountTypeServiceAccount,
|
||||||
|
Credentials: map[string]any{"project_id": "vertex-proj", "location": "us-east5"},
|
||||||
|
}
|
||||||
|
// body 带了 context_management 字段(客户端透传 / normalize 补齐 / mimicry 注入等场景都可能导致)
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
svc := &GatewayService{}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"vertex-token", "service_account", "claude-haiku-4-5@20251001", false, false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got := readRequestBodyForTest(t, req)
|
||||||
|
require.False(t, gjson.GetBytes(got, "context_management").Exists(),
|
||||||
|
"Vertex 路径下客户端 header 缺 context-management beta 时,必须 strip body 同名字段")
|
||||||
|
// header 对称断言:覆盖未来某人在 Vertex builder 里加入与 sanitize 不一致的 header 处理。
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
require.False(t, anthropicBetaTokensContains(outBeta, "context-management-2025-06-27"),
|
||||||
|
"与 body 对称:outgoing anthropic-beta header 也不含 context-management beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vertex 路径反面:客户端 header 含 context-management beta 时保留字段。
|
||||||
|
func TestGatewayService_BuildAnthropicVertexServiceAccount_PreservesContextManagementWhenBetaPresent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14,context-management-2025-06-27")
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 303, Platform: PlatformAnthropic, Type: AccountTypeServiceAccount,
|
||||||
|
Credentials: map[string]any{"project_id": "vertex-proj", "location": "us-east5"},
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
|
||||||
|
svc := &GatewayService{}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"vertex-token", "service_account", "claude-sonnet-4-6@20260218", false, false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got := readRequestBodyForTest(t, req)
|
||||||
|
require.True(t, gjson.GetBytes(got, "context_management").Exists(),
|
||||||
|
"Vertex + 客户端 header 包含 context-management beta 时字段必须保留")
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
require.True(t, anthropicBetaTokensContains(outBeta, "context-management-2025-06-27"),
|
||||||
|
"与 body 对称:outgoing anthropic-beta header 同步含 context-management beta")
|
||||||
|
}
|
||||||
|
|||||||
667
backend/internal/service/gateway_context_management_test.go
Normal file
667
backend/internal/service/gateway_context_management_test.go
Normal file
@ -0,0 +1,667 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// 背景
|
||||||
|
// ============================================================================
|
||||||
|
//
|
||||||
|
// Anthropic 上游对 body.context_management 字段实施 Pydantic schema 校验:
|
||||||
|
// 当且仅当 anthropic-beta header 含 context-management-2025-06-27 时接受。
|
||||||
|
// 否则报:
|
||||||
|
// "context_management: Extra inputs are not permitted"
|
||||||
|
//
|
||||||
|
// 本仓采用能力维度对称约束(与 Bedrock 路径的 sanitizeBedrockFieldsForBetaTokens
|
||||||
|
// 对称):在所有 Anthropic 直连出口,按最终 anthropic-beta header 是否含上述 token
|
||||||
|
// 决定 body 是否保留同名字段。
|
||||||
|
//
|
||||||
|
// 本文件覆盖:
|
||||||
|
// 1) sanitizeAnthropicBodyForBetaTokens 纯函数
|
||||||
|
// 2) anthropicBetaTokensContains 解析辅助函数
|
||||||
|
// 3) computeFinalAnthropicBeta / computeFinalCountTokensAnthropicBeta 各路径
|
||||||
|
// 4) normalizeClaudeOAuthRequestBody 的 context_management 补齐行为(不再按 model 短路)
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// anthropicBetaTokensContains
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestAnthropicBetaTokensContains_EmptyInputs(t *testing.T) {
|
||||||
|
require.False(t, anthropicBetaTokensContains("", "context-management-2025-06-27"))
|
||||||
|
require.False(t, anthropicBetaTokensContains("oauth-2025-04-20", ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicBetaTokensContains_SingleToken(t *testing.T) {
|
||||||
|
require.True(t, anthropicBetaTokensContains("context-management-2025-06-27", "context-management-2025-06-27"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicBetaTokensContains_MultiTokenComma(t *testing.T) {
|
||||||
|
header := "oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14"
|
||||||
|
require.True(t, anthropicBetaTokensContains(header, "context-management-2025-06-27"))
|
||||||
|
require.True(t, anthropicBetaTokensContains(header, "oauth-2025-04-20"))
|
||||||
|
require.False(t, anthropicBetaTokensContains(header, "fast-mode-2026-02-01"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicBetaTokensContains_ToleratesWhitespace(t *testing.T) {
|
||||||
|
header := "oauth-2025-04-20 , context-management-2025-06-27 , interleaved-thinking-2025-05-14"
|
||||||
|
require.True(t, anthropicBetaTokensContains(header, "context-management-2025-06-27"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicBetaTokensContains_SubstringNotMatched(t *testing.T) {
|
||||||
|
// 严格 token 比较,不应被子串误匹配
|
||||||
|
require.False(t, anthropicBetaTokensContains("context-management-2025-06-27-rev2", "context-management-2025-06-27"),
|
||||||
|
"必须按 token 边界匹配,不允许 prefix 子串误命中")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// sanitizeAnthropicBodyForBetaTokens
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_NoFieldNoChange(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","messages":[]}`)
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "oauth-2025-04-20")
|
||||||
|
require.False(t, changed)
|
||||||
|
require.Equal(t, string(body), string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_FieldKeptWhenBetaPresent(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-opus-4-7","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens(body,
|
||||||
|
"oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14")
|
||||||
|
require.False(t, changed)
|
||||||
|
require.True(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
require.Equal(t, "clear_thinking_20251015",
|
||||||
|
gjson.GetBytes(out, "context_management.edits.0.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_FieldStrippedWhenBetaMissing(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "oauth-2025-04-20,interleaved-thinking-2025-05-14")
|
||||||
|
require.True(t, changed)
|
||||||
|
require.False(t, gjson.GetBytes(out, "context_management").Exists(),
|
||||||
|
"header 不含 context-management beta 时必须 strip 同名字段")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_FieldStrippedWhenBetaEmpty(t *testing.T) {
|
||||||
|
body := []byte(`{"context_management":{"edits":[]},"messages":[]}`)
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens(body, "")
|
||||||
|
require.True(t, changed)
|
||||||
|
require.False(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_EmptyBody(t *testing.T) {
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens([]byte{}, "")
|
||||||
|
require.False(t, changed)
|
||||||
|
require.Empty(t, out)
|
||||||
|
|
||||||
|
out, changed = sanitizeAnthropicBodyForBetaTokens(nil, "")
|
||||||
|
require.False(t, changed)
|
||||||
|
require.Empty(t, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ★ 关键回归断言:能力维度 sanitize 解决了 "真 CC + haiku" 路径的过度删除问题。
|
||||||
|
// 真实 Claude Code CLI 2.1.87+ 客户端 header 含 context-management beta;
|
||||||
|
// 即使 model 是 haiku,sanitize 也不应剥离功能字段。
|
||||||
|
func TestSanitizeAnthropicBodyForBetaTokens_HaikuRealCCClientPreservesField(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
|
||||||
|
// 真 Claude Code CLI 2.1.87+ 客户端 header 含 context-management beta
|
||||||
|
clientBeta := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27"
|
||||||
|
out, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta)
|
||||||
|
require.False(t, changed,
|
||||||
|
"真 CC 客户端 header 含 context-management beta 时,haiku body 字段必须保留(功能不丢)")
|
||||||
|
require.True(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// computeFinalAnthropicBeta — 关键路径
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func newTestGatewayServiceForBeta(injectBetaForAPIKey bool) *GatewayService {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.InjectBetaForAPIKey = injectBetaForAPIKey
|
||||||
|
return &GatewayService{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthMimic_NonHaiku_IncludesContextManagement(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", http.Header{}, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"OAuth mimic non-haiku 必须注入完整 CC mimicry beta,含 context-management-2025-06-27")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaOAuth))
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaClaudeCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthMimic_Haiku_ExcludesContextManagement(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"OAuth mimic haiku 仅注入 oauth + interleaved-thinking,不含 context-management")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaOAuth))
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaInterleavedThinking))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthMimic_IgnoresClientBeta(t *testing.T) {
|
||||||
|
// mimic 路径下原代码白名单透传被跳过,client beta 应被忽略
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "custom-experimental-beta")
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, strings.Contains(final, "custom-experimental-beta"),
|
||||||
|
"mimic 路径必须忽略客户端 anthropic-beta header")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthTransparent_NonHaiku_PreservesClientContextManagement(t *testing.T) {
|
||||||
|
// 真 CC 客户端透传:客户端 header 中的 context-management beta 必须保留
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,context-management-2025-06-27")
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthTransparent_Haiku_RealCCPreservesContextManagement(t *testing.T) {
|
||||||
|
// haiku 透传 + 客户端带 context-management beta → 必须保留
|
||||||
|
// (能力维度核心场景:避免 model-name 误删客户端透传的功能 beta)
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,context-management-2025-06-27,interleaved-thinking-2025-05-14")
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", false, "claude-haiku-4-5", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"真 CC + haiku + 客户端带 context-management beta → 透传必须保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_APIKey_PassesClientBetaThroughDropSet(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "oauth-2025-04-20,custom-beta")
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("apikey", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, "oauth-2025-04-20"))
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, "custom-beta"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalAnthropicBeta_APIKey_NoClientBetaInjectOff_ShouldNotSet(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("apikey", false, "claude-sonnet-4-6", http.Header{}, []byte(`{}`), nil)
|
||||||
|
require.False(t, ok, "API-key + 客户端未传 + InjectBetaForAPIKey 关 → 不应主动设置 anthropic-beta")
|
||||||
|
require.Equal(t, "", final)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// computeFinalCountTokensAnthropicBeta
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestComputeFinalCountTokensAnthropicBeta_OAuthMimic_AlwaysIncludesContextManagement(t *testing.T) {
|
||||||
|
// count_tokens 路径下 mimic 不按 haiku 排除:始终注入完整 mimicry beta
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", true, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"count_tokens + mimic 即使 haiku 也注入 context-management beta(与 messages 不同)")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
|
||||||
|
"count_tokens 路径必须含 token-counting beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重构等价性回归:
|
||||||
|
// 原 main buildCountTokensRequest 在 count_tokens mimic 分支上不跳过白名单透传
|
||||||
|
// (与 messages mimic 不同),incomingBeta 取自客户端透传。重构后必须从 clientHeaders
|
||||||
|
// 拿同一个值并 merge,否则会丢失客户端 beta。
|
||||||
|
func TestComputeFinalCountTokensAnthropicBeta_OAuthMimic_PreservesClientBeta(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "custom-experimental-beta,context-1m-2025-08-07")
|
||||||
|
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", true, "claude-haiku-4-5", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, "custom-experimental-beta"),
|
||||||
|
"count_tokens mimic 不同于 messages mimic:原代码会保留客户端透传的 beta")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, "context-1m-2025-08-07"),
|
||||||
|
"客户端透传的其他 beta token 同样需要保留")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"同时 FullClaudeCodeMimicryBetas 不打折扣")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
|
||||||
|
"同时补齐 token-counting beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
// messages mimic 路径反向验证:原代码会跳过白名单透传,
|
||||||
|
// 客户端 beta 不会进入 mimic 计算。重构后 messages computeFinalAnthropicBeta
|
||||||
|
// mimic 分支依然不该使用 clientBeta。
|
||||||
|
func TestComputeFinalAnthropicBeta_OAuthMimic_IgnoresClientBetaExplicit(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "custom-experimental-beta")
|
||||||
|
final, ok := s.computeFinalAnthropicBeta("oauth", true, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, anthropicBetaTokensContains(final, "custom-experimental-beta"),
|
||||||
|
"messages mimic 原代码跳过白名单透传 → 客户端 beta 不进入计算。"+
|
||||||
|
"与 count_tokens mimic 是不同的设计,不能合并为同一函数。")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalCountTokensAnthropicBeta_OAuthTransparent_NoClientBetaInjectsDefault(t *testing.T) {
|
||||||
|
// 真 CC 客户端透传 + 客户端未传 anthropic-beta → 用 CountTokensBetaHeader 兜底
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", false, "claude-haiku-4-5", http.Header{}, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, claude.CountTokensBetaHeader, final)
|
||||||
|
// CountTokensBetaHeader 不含 context-management beta
|
||||||
|
require.False(t, anthropicBetaTokensContains(final, claude.BetaContextManagement))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeFinalCountTokensAnthropicBeta_OAuthTransparent_AppendsBetaTokenCounting(t *testing.T) {
|
||||||
|
s := newTestGatewayServiceForBeta(false)
|
||||||
|
hdr := http.Header{}
|
||||||
|
hdr.Set("anthropic-beta", "oauth-2025-04-20,context-management-2025-06-27")
|
||||||
|
final, ok := s.computeFinalCountTokensAnthropicBeta("oauth", false, "claude-sonnet-4-6", hdr, []byte(`{}`), nil)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaTokenCounting),
|
||||||
|
"客户端未带 token-counting beta 时必须补齐")
|
||||||
|
require.True(t, anthropicBetaTokensContains(final, claude.BetaContextManagement),
|
||||||
|
"客户端带的 context-management beta 必须保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// normalizeClaudeOAuthRequestBody — 回归:context_management 补齐恢复原行为
|
||||||
|
// ============================================================================
|
||||||
|
//
|
||||||
|
// 重构后该函数不再按 model 名短路:thinking=enabled/adaptive 时补齐 context_management,
|
||||||
|
// 与 model 无关。strip 责任移交 sanitizeAnthropicBodyForBetaTokens(在
|
||||||
|
// buildUpstreamRequest 层按最终 beta header 执行)。
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_InjectsContextManagement_ThinkingEnabled(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-6","thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
|
||||||
|
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-sonnet-4-6", claudeOAuthNormalizeOptions{})
|
||||||
|
require.True(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
require.Equal(t, "clear_thinking_20251015",
|
||||||
|
gjson.GetBytes(out, "context_management.edits.0.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_InjectsContextManagement_ThinkingAdaptive(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-opus-4-7","thinking":{"type":"adaptive"},"messages":[]}`)
|
||||||
|
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-opus-4-7", claudeOAuthNormalizeOptions{})
|
||||||
|
require.True(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_HaikuStillInjects_StripDeferredToSanitize(t *testing.T) {
|
||||||
|
// haiku + thinking=enabled:normalize 阶段仍按 CLI mimicry 行为补齐字段;
|
||||||
|
// strip 由 buildUpstreamRequest 层的 sanitize 兜底(如果 final beta 不含 token)。
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
|
||||||
|
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-haiku-4-5", claudeOAuthNormalizeOptions{})
|
||||||
|
require.True(t, gjson.GetBytes(out, "context_management").Exists(),
|
||||||
|
"normalize 不再按 model 名短路;strip 责任移交 sanitize 层")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_PreservesClientContextManagement(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-opus-4-7","context_management":{"edits":[{"type":"custom_strategy"}]},"thinking":{"type":"enabled","budget_tokens":1000},"messages":[]}`)
|
||||||
|
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-opus-4-7", claudeOAuthNormalizeOptions{})
|
||||||
|
require.Equal(t, "custom_strategy",
|
||||||
|
gjson.GetBytes(out, "context_management.edits.0.type").String(),
|
||||||
|
"客户端透传的 context_management 内容必须原样保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeOAuthRequestBody_NoThinking_NoInject(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-6","messages":[]}`)
|
||||||
|
out, _ := normalizeClaudeOAuthRequestBody(body, "claude-sonnet-4-6", claudeOAuthNormalizeOptions{})
|
||||||
|
require.False(t, gjson.GetBytes(out, "context_management").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// passthrough 集成测试:buildUpstreamRequest-
|
||||||
|
// AnthropicAPIKeyPassthrough 与 buildCountTokensRequestAnthropicAPIKeyPassthrough
|
||||||
|
// 路径上 sanitize 是否生效。
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// passthrough 集成测试不设 base_url,避开 validateUpstreamBaseURL 对 cfg.Security 的依赖。
|
||||||
|
// targetURL 会走默认 claudeAPIURL,sanitize 逻辑与 baseURL 是否存在无关。
|
||||||
|
func newAnthropicAPIKeyPassthroughAccountForBetaTest() *Account {
|
||||||
|
return &Account{
|
||||||
|
ID: 501,
|
||||||
|
Name: "anthropic-apikey-passthrough-ctxmgmt-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readUpstreamBodyForTest(t *testing.T, req *http.Request) []byte {
|
||||||
|
t.Helper()
|
||||||
|
require.NotNil(t, req.Body)
|
||||||
|
b, err := io.ReadAll(req.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequestAnthropicAPIKeyPassthrough_StripsContextManagementWhenClientHeaderMissingBeta(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
// 客户端仅带 oauth beta,不带 context-management-2025-06-27
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20")
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||||
|
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
|
||||||
|
"API-key passthrough + 客户端未带 context-management beta → strip body 字段")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequestAnthropicAPIKeyPassthrough_PreservesContextManagementWhenClientHeaderHasBeta(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,context-management-2025-06-27")
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||||
|
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
|
||||||
|
"API-key passthrough + 客户端带 context-management beta → 字段保留(不过度删除)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCountTokensRequestAnthropicAPIKeyPassthrough_StripsContextManagementWhenClientHeaderMissingBeta(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,token-counting-2024-11-01")
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||||
|
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
|
||||||
|
"count_tokens passthrough + 客户端未带 context-management beta → strip")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// 集成测试:buildUpstreamRequest
|
||||||
|
// 全路径验证上游 outgoing body 与 anthropic-beta header 严格对称。
|
||||||
|
// 这个测试能挡住未来某人忘调 sanitize / 将 sanitize 挪到 CCH 之后 等 regression。
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequest_OAuthMimicHaiku_StripsContextManagementEndToEnd(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
account := &Account{ID: 401, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-tok"},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
// haiku + mimic CC → final beta = HaikuBetaHeader(不含 context-management)→
|
||||||
|
// body 必须 strip。
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"oauth-tok", "oauth", "claude-haiku-4-5", false, true, // mimicClaudeCode=true
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
|
||||||
|
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"OAuth mimic + haiku 端到端:outgoing body 不应含 context_management")
|
||||||
|
require.False(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
|
||||||
|
"对称约束:outgoing anthropic-beta header 也不带 context-management beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequest_OAuthMimicNonHaiku_PreservesContextManagementEndToEnd(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
account := &Account{ID: 402, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-tok"},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
// sonnet + mimic CC → final beta = FullClaudeCodeMimicryBetas(含 context-management)→
|
||||||
|
// body 保留。
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"oauth-tok", "oauth", "claude-sonnet-4-6", false, true,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
|
||||||
|
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"OAuth mimic + non-haiku:outgoing body 必须保留 context_management。")
|
||||||
|
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
|
||||||
|
"对称约束:outgoing anthropic-beta header 同时含 context-management beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequest_OAuthTransparentHaikuWithRealCCBeta_PreservesField(t *testing.T) {
|
||||||
|
// 端到端验证:真 CC 客户端 + haiku + 客户端 header 带 context-management beta
|
||||||
|
// → final beta 透传 → 不应该过度删除 body 字段
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta",
|
||||||
|
"claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27")
|
||||||
|
|
||||||
|
account := &Account{ID: 403, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-tok"},
|
||||||
|
Status: StatusActive, Schedulable: true,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"oauth-tok", "oauth", "claude-haiku-4-5", false, false, // mimicClaudeCode=false(真 CC)
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
|
||||||
|
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
|
||||||
|
"真 CC 透传路径:客户端 header 中的 context-management beta 必须保留")
|
||||||
|
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"回归保护:真 CC + haiku + 客户端带 beta token 时,clear_thinking_20251015 功能不能静默失效")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CCH 顺序语义测试:sanitize 必须在 signBillingHeaderCCH 之前,
|
||||||
|
// 否则签名的 hash 与最终发送的 body 不一致,被 Anthropic 判 third-party。
|
||||||
|
//
|
||||||
|
// 该测试不走 buildUpstreamRequest 完整路径(需要 mock SettingService 成本高),
|
||||||
|
// 而是直接验证两个顺序产生的 cch 不同,证明二者不可交换。
|
||||||
|
// 测试名本身是语义约束的文档化 marker。
|
||||||
|
func TestSanitizeMustBeBeforeCCHSigning_HashConsistency(t *testing.T) {
|
||||||
|
// 构造 body:含 context_management + cch=00000 占位符
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92; cch=00000;"}],"messages":[]}`)
|
||||||
|
|
||||||
|
// 最终发送场景:final beta 不含 context-management beta → sanitize 会 strip
|
||||||
|
finalBeta := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||||
|
|
||||||
|
extractCCH := func(t *testing.T, b []byte) string {
|
||||||
|
t.Helper()
|
||||||
|
m := regexp.MustCompile(`\bcch=([0-9a-fA-F]{5})\b`).FindSubmatch(b)
|
||||||
|
require.NotNil(t, m, "body 里找不到 cch=<5hex> :%s", string(b))
|
||||||
|
return string(m[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 正确顺序:sanitize → signBillingHeaderCCH ===
|
||||||
|
// 1. strip context_management
|
||||||
|
sanitizedFirst, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBeta)
|
||||||
|
require.True(t, changed)
|
||||||
|
require.False(t, gjson.GetBytes(sanitizedFirst, "context_management").Exists())
|
||||||
|
// 2. 基于“strip 后的 body”算 hash
|
||||||
|
correctFinal := signBillingHeaderCCH(sanitizedFirst)
|
||||||
|
correctCCH := extractCCH(t, correctFinal)
|
||||||
|
require.NotEqual(t, "00000", correctCCH, "placeholder 应被替换")
|
||||||
|
|
||||||
|
// === 错误顺序:signBillingHeaderCCH → sanitize(未来 regression 场景)===
|
||||||
|
// 1. 先基于“含 context_management 的 body”算 hash → cch=H_with
|
||||||
|
signedFirst := signBillingHeaderCCH(body)
|
||||||
|
wrongCCH := extractCCH(t, signedFirst)
|
||||||
|
require.NotEqual(t, "00000", wrongCCH)
|
||||||
|
// 2. 后 strip context_management → body 变化但 cch 仍是 H_with
|
||||||
|
wrongFinal, _ := sanitizeAnthropicBodyForBetaTokens(signedFirst, finalBeta)
|
||||||
|
wrongFinalCCH := extractCCH(t, wrongFinal)
|
||||||
|
|
||||||
|
// === 关键断言 ===
|
||||||
|
// 上游验证逻辑:将 outgoing body 的 cch 还原为 00000、重算 hash、与 cch 字段比较。
|
||||||
|
// 模拟上游验证:用发送 body 算出“期望的 cch”,与发送 body 里的 cch 字段比。
|
||||||
|
recomputeExpected := func(b []byte, currentCCH string) string {
|
||||||
|
t.Helper()
|
||||||
|
// 把 cch=<currentCCH> 还原为 cch=00000
|
||||||
|
re := regexp.MustCompile(`(\bcch=)` + currentCCH + `(\b)`)
|
||||||
|
restored := re.ReplaceAll(b, []byte("${1}00000${2}"))
|
||||||
|
return extractCCH(t, signBillingHeaderCCH(restored))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 正确顺序:发送 body 的 cch == 重算 hash → 上游验证过
|
||||||
|
require.Equal(t, correctCCH, recomputeExpected(correctFinal, correctCCH),
|
||||||
|
"正确顺序:final body 里的 cch 与重算 hash 一致 → 上游验证通过")
|
||||||
|
|
||||||
|
// 错误顺序:发送 body 的 cch 是“含 ctx 算的”,但最终 body 不含 ctx → 重算 hash 不同
|
||||||
|
require.NotEqual(t, wrongFinalCCH, recomputeExpected(wrongFinal, wrongFinalCCH),
|
||||||
|
"错误顺序:final body 里的 cch 是基于含 ctx 的 body 算的,"+
|
||||||
|
"但发送 body 已 strip ctx → 上游重算 hash 与 cch 不一致 → 被判 third-party。"+
|
||||||
|
"这是 buildUpstreamRequest / buildCountTokensRequest 里 sanitize 必须在 "+
|
||||||
|
"signBillingHeaderCCH 之前的原因。")
|
||||||
|
}
|
||||||
|
|
||||||
|
// count_tokens 主路径 E2E 集成测试
|
||||||
|
func TestBuildCountTokensRequest_OAuthMimicHaiku_PreservesContextManagementEndToEnd(t *testing.T) {
|
||||||
|
// count_tokens 路径下 mimic 不按 haiku 排除,始终注入 BetaContextManagement
|
||||||
|
// → sanitize 看到最终 beta header 含 context-management beta → 字段保留。
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
account := &Account{ID: 411, Platform: PlatformAnthropic, Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-tok"},
|
||||||
|
Status: StatusActive, Schedulable: true,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildCountTokensRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"oauth-tok", "oauth", "claude-haiku-4-5", true, // mimicClaudeCode=true
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
outBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
|
|
||||||
|
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaContextManagement),
|
||||||
|
"count_tokens mimic 始终注入 context-management beta")
|
||||||
|
require.True(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"对称约束:final beta 含 token 时 body 字段保留")
|
||||||
|
require.True(t, anthropicBetaTokensContains(outBeta, claude.BetaTokenCounting),
|
||||||
|
"count_tokens 路径必须含 token-counting beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCountTokensRequest_APIKeyHaiku_StripsContextManagementEndToEnd(t *testing.T) {
|
||||||
|
// API-key + haiku + 客户端 header 不带 context-management beta → final beta 不含 → strip
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||||
|
|
||||||
|
account := &Account{ID: 412, Platform: PlatformAnthropic, Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{"api_key": "sk-ant-xxx"},
|
||||||
|
Status: StatusActive, Schedulable: true,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildCountTokensRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"sk-ant-xxx", "apikey", "claude-haiku-4-5", false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"count_tokens API-key + 客户端未带 beta token → body strip")
|
||||||
|
}
|
||||||
|
|
||||||
|
// count_tokens passthrough preserve 测试
|
||||||
|
func TestBuildCountTokensRequestAnthropicAPIKeyPassthrough_PreservesContextManagementWhenClientHeaderHasBeta(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "oauth-2025-04-20,context-management-2025-06-27,token-counting-2024-11-01")
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[{"type":"clear_thinking_20251015"}]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||||
|
context.Background(), c, newAnthropicAPIKeyPassthroughAccountForBetaTest(), body, "token",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, gjson.GetBytes(readUpstreamBodyForTest(t, req), "context_management").Exists(),
|
||||||
|
"count_tokens passthrough + 客户端带 context-management beta → 字段保留")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpstreamRequest_APIKeyHaikuWithContextManagement_StripsField(t *testing.T) {
|
||||||
|
// API-key + haiku + body 带 context_management + 客户端 header 未带 context-management beta
|
||||||
|
// → final beta 不含 → body 字段被 strip
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
|
||||||
|
|
||||||
|
account := &Account{ID: 404, Platform: PlatformAnthropic, Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{"api_key": "sk-ant-xxx"},
|
||||||
|
Status: StatusActive, Schedulable: true,
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-haiku-4-5","context_management":{"edits":[]},"messages":[]}`)
|
||||||
|
svc := &GatewayService{cfg: &config.Config{}}
|
||||||
|
req, err := svc.buildUpstreamRequest(
|
||||||
|
context.Background(), c, account, body,
|
||||||
|
"sk-ant-xxx", "apikey", "claude-haiku-4-5", false, false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outBody := readUpstreamBodyForTest(t, req)
|
||||||
|
require.False(t, gjson.GetBytes(outBody, "context_management").Exists(),
|
||||||
|
"API-key + haiku + 客户端未带 beta token → body 字段必须被 strip")
|
||||||
|
}
|
||||||
@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@ -685,6 +686,69 @@ func removeThinkingDependentContextStrategies(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anthropicBetaContextManagementToken 是 context_management 字段受的 beta token。
|
||||||
|
// 与 claude.BetaContextManagement 保持一致;在本文件本地定义以避免震荡
|
||||||
|
// claude package 的该常量含义。
|
||||||
|
const anthropicBetaContextManagementToken = "context-management-2025-06-27"
|
||||||
|
|
||||||
|
// sanitizeAnthropicBodyForBetaTokens 是对 Anthropic 直连路径上 body↔beta header
|
||||||
|
// **能力维度**对称约束的统一实现,与 Bedrock 路径的
|
||||||
|
// `sanitizeBedrockFieldsForBetaTokens` 对称。
|
||||||
|
//
|
||||||
|
// 问题场景:
|
||||||
|
// - context_management 是 Claude Code CLI 2.1.87+ 默认携带的 beta 字段
|
||||||
|
// (含 clear_thinking_20251015 等清理策略)
|
||||||
|
// - 其被 Anthropic 上游接受的前提是 anthropic-beta header 含
|
||||||
|
// `context-management-2025-06-27`
|
||||||
|
// - 若两侧不一致上游 Pydantic schema 拒收:
|
||||||
|
// "context_management: Extra inputs are not permitted"
|
||||||
|
//
|
||||||
|
// 本函数按最终发送的 anthropic-beta header 决定是否保留 body 中的
|
||||||
|
// context_management 字段:缺 beta token → strip。这将限制完全建立在
|
||||||
|
// "能力维度" 上,与 model 名 / token type / mimicry 子路径无关。
|
||||||
|
//
|
||||||
|
// 调用约束:必须在 CCH 签名之前调用,否则签名 hash 与最终 body
|
||||||
|
// 不一致,上游会以 third-party 拒收。
|
||||||
|
//
|
||||||
|
// 返回 (sanitized, changed):changed 表示是否发生实际删除,供调用方决定
|
||||||
|
// 是否重用原 body 引用。
|
||||||
|
func sanitizeAnthropicBodyForBetaTokens(body []byte, anthropicBetaHeader string) ([]byte, bool) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(body, "context_management").Exists() {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
if anthropicBetaTokensContains(anthropicBetaHeader, anthropicBetaContextManagementToken) {
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
if b, err := sjson.DeleteBytes(body, "context_management"); err == nil {
|
||||||
|
return b, true
|
||||||
|
} else {
|
||||||
|
// 不应发生:gjson 刚验证过字段存在 + body 是合法 JSON。如果 sjson 仍报错,
|
||||||
|
// 调用方会拿到 (body, false),但此前 computeFinalAnthropicBeta 已按“strip 后”
|
||||||
|
// 计算了 finalBeta——两侧会不一致。记录 warning 最小限度提醒运维。
|
||||||
|
logger.LegacyPrintf("service.gateway",
|
||||||
|
"[CtxMgmtSanitize] sjson.DeleteBytes failed unexpectedly: %v (body len=%d). "+
|
||||||
|
"body and final anthropic-beta header may be out of sync.", err, len(body))
|
||||||
|
}
|
||||||
|
return body, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// anthropicBetaTokensContains 检测逗号分隔的 anthropic-beta header 是否含指定 token。
|
||||||
|
// 宋体空格宽容;区分大小写(Anthropic beta token 始终是小写)。
|
||||||
|
func anthropicBetaTokensContains(header, token string) bool {
|
||||||
|
if header == "" || token == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, part := range strings.Split(header, ",") {
|
||||||
|
if strings.TrimSpace(part) == token {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
||||||
// signature/thought_signature validation issues involving tool blocks.
|
// signature/thought_signature validation issues involving tool blocks.
|
||||||
//
|
//
|
||||||
|
|||||||
@ -26,12 +26,9 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claudemask"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/telemetry"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
@ -282,10 +279,6 @@ func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account,
|
|||||||
interesting := []string{
|
interesting := []string{
|
||||||
"user-agent",
|
"user-agent",
|
||||||
"x-app",
|
"x-app",
|
||||||
"x-client-app",
|
|
||||||
"x-claude-remote-session-id",
|
|
||||||
"x-claude-remote-container-id",
|
|
||||||
"x-anthropic-additional-protection",
|
|
||||||
"anthropic-dangerous-direct-browser-access",
|
"anthropic-dangerous-direct-browser-access",
|
||||||
"anthropic-version",
|
"anthropic-version",
|
||||||
"anthropic-beta",
|
"anthropic-beta",
|
||||||
@ -404,10 +397,6 @@ var allowedHeaders = map[string]bool{
|
|||||||
"accept-encoding": true,
|
"accept-encoding": true,
|
||||||
"x-claude-code-session-id": true,
|
"x-claude-code-session-id": true,
|
||||||
"x-client-request-id": true,
|
"x-client-request-id": true,
|
||||||
"x-client-app": true,
|
|
||||||
"x-claude-remote-session-id": true,
|
|
||||||
"x-claude-remote-container-id": true,
|
|
||||||
"x-anthropic-additional-protection": true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GatewayCache 定义网关服务的缓存操作接口。
|
// GatewayCache 定义网关服务的缓存操作接口。
|
||||||
@ -428,16 +417,6 @@ type GatewayCache interface {
|
|||||||
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||||
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||||
|
|
||||||
// GetCascadeID 获取 Windsurf Cascade 会话 ID(用于 LS 多轮复用)
|
|
||||||
// Get the Windsurf Cascade ID bound to a chat session for multi-turn LS reuse.
|
|
||||||
GetCascadeID(ctx context.Context, key string) (string, error)
|
|
||||||
// SetCascadeID 写入 Cascade 会话 ID
|
|
||||||
// Persist the Cascade session ID with the given TTL.
|
|
||||||
SetCascadeID(ctx context.Context, key string, cascadeID string, ttl time.Duration) error
|
|
||||||
// DeleteCascadeID 失效 Cascade 会话 ID(panel-not-found / 错误时调用)
|
|
||||||
// Invalidate the cached Cascade session ID on panel-not-found or upstream error.
|
|
||||||
DeleteCascadeID(ctx context.Context, key string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||||
@ -600,8 +579,7 @@ type GatewayService struct {
|
|||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
|
|
||||||
userGroupRateResolver *userGroupRateResolver
|
userGroupRateResolver *userGroupRateResolver
|
||||||
userGroupRateCache *gocache.Cache
|
userGroupRateCache *gocache.Cache
|
||||||
userGroupRateSF singleflight.Group
|
userGroupRateSF singleflight.Group
|
||||||
@ -647,7 +625,6 @@ func NewGatewayService(
|
|||||||
channelService *ChannelService,
|
channelService *ChannelService,
|
||||||
resolver *ModelPricingResolver,
|
resolver *ModelPricingResolver,
|
||||||
balanceNotifyService *BalanceNotifyService,
|
balanceNotifyService *BalanceNotifyService,
|
||||||
rpmTokenBucketSvc *RPMTokenBucketService,
|
|
||||||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||||
@ -675,7 +652,6 @@ func NewGatewayService(
|
|||||||
claudeTokenProvider: claudeTokenProvider,
|
claudeTokenProvider: claudeTokenProvider,
|
||||||
sessionLimitCache: sessionLimitCache,
|
sessionLimitCache: sessionLimitCache,
|
||||||
rpmCache: rpmCache,
|
rpmCache: rpmCache,
|
||||||
rpmTokenBucket: rpmTokenBucketSvc,
|
|
||||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||||
@ -1179,6 +1155,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
// context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动
|
// context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动
|
||||||
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
|
// 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
|
||||||
// 客户端显式传了就透传;否则按 CLI 行为补齐。
|
// 客户端显式传了就透传;否则按 CLI 行为补齐。
|
||||||
|
//
|
||||||
|
// 注:本函数不按 model 名决定是否保留 context_management。“最终 beta
|
||||||
|
// header 不含 context-management-2025-06-27 时 strip 字段”的能力维度
|
||||||
|
// 对称约束由 sanitizeAnthropicBodyForBetaTokens 在 buildUpstreamRequest /
|
||||||
|
// buildCountTokensRequest 层统一执行,与 Bedrock 路径的
|
||||||
|
// sanitizeBedrockFieldsForBetaTokens 对称。
|
||||||
if !gjson.GetBytes(out, "context_management").Exists() {
|
if !gjson.GetBytes(out, "context_management").Exists() {
|
||||||
thinkingType := gjson.GetBytes(out, "thinking.type").String()
|
thinkingType := gjson.GetBytes(out, "thinking.type").String()
|
||||||
if thinkingType == "enabled" || thinkingType == "adaptive" {
|
if thinkingType == "enabled" || thinkingType == "adaptive" {
|
||||||
@ -1431,9 +1413,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||||
if platform == PlatformAnthropic && s.enableTierFallbackChain() {
|
|
||||||
return s.selectAccountWithTierFallback(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
|
||||||
}
|
|
||||||
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -2382,14 +2361,6 @@ func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, gr
|
|||||||
return len(accounts) == 1
|
return len(accounts) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) IsSingleWindsurfAccountGroup(ctx context.Context, groupID *int64) bool {
|
|
||||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformWindsurf, true)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return len(accounts) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return false
|
return false
|
||||||
@ -2717,15 +2688,6 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
|
|
||||||
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
|
|
||||||
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
|
||||||
if s.rpmTokenBucket == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||||
@ -3754,12 +3716,6 @@ func summarizeSelectionFailureStats(stats selectionFailureStats) string {
|
|||||||
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
|
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
|
||||||
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
|
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
|
||||||
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
|
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
|
||||||
if account.Platform == PlatformWindsurf {
|
|
||||||
if strings.TrimSpace(requestedModel) == "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return windsurf.ResolveModel(requestedModel) != ""
|
|
||||||
}
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
if strings.TrimSpace(requestedModel) == "" {
|
if strings.TrimSpace(requestedModel) == "" {
|
||||||
return true
|
return true
|
||||||
@ -3784,12 +3740,6 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex
|
|||||||
|
|
||||||
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台)
|
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台)
|
||||||
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||||
if account.Platform == PlatformWindsurf {
|
|
||||||
if strings.TrimSpace(requestedModel) == "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return windsurf.ResolveModel(requestedModel) != ""
|
|
||||||
}
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
if strings.TrimSpace(requestedModel) == "" {
|
if strings.TrimSpace(requestedModel) == "" {
|
||||||
return true
|
return true
|
||||||
@ -4148,12 +4098,10 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
|
|||||||
// 模型仍通过 messages 接收完整指令,保留客户端功能
|
// 模型仍通过 messages 接收完整指令,保留客户端功能
|
||||||
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
|
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||||
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
|
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
|
||||||
// 规范化 env 字段(Platform/Shell/OS/路径),防止真实机器信息被 Anthropic 用作跨账号关联信号。
|
|
||||||
normalizedSystemText := NormalizeSystemPromptEnv(originalSystemText)
|
|
||||||
instrMsg, err1 := json.Marshal(map[string]any{
|
instrMsg, err1 := json.Marshal(map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": []map[string]any{
|
"content": []map[string]any{
|
||||||
{"type": "text", "text": "[System Instructions]\n" + normalizedSystemText},
|
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
ackMsg, err2 := json.Marshal(map[string]any{
|
ackMsg, err2 := json.Marshal(map[string]any{
|
||||||
@ -4599,44 +4547,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
|
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
|
||||||
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)
|
||||||
|
|
||||||
// Bootstrap 预热:模拟真实 CLI 启动时的 GET /api/claude_cli/bootstrap 调用
|
|
||||||
// 真实 CLI 在首次 messages 请求前 fire-and-forget 调用此端点。
|
|
||||||
if tokenType == "oauth" && token != "" {
|
|
||||||
instanceSalt := ""
|
|
||||||
useSharedUpstream := proxyURL != "" || tlsProfile != nil
|
|
||||||
if s.cfg != nil {
|
|
||||||
instanceSalt = s.cfg.Gateway.InstanceSalt
|
|
||||||
useSharedUpstream = useSharedUpstream || s.cfg.Gateway.TLSFingerprint.Enabled
|
|
||||||
}
|
|
||||||
TriggerBootstrapIfNeeded(account.ID, token, &BackgroundRequestOptions{
|
|
||||||
ProxyURL: proxyURL,
|
|
||||||
HTTPUpstream: s.httpUpstream,
|
|
||||||
TLSProfile: tlsProfile,
|
|
||||||
InstanceSalt: instanceSalt,
|
|
||||||
UseSharedUpstream: useSharedUpstream,
|
|
||||||
})
|
|
||||||
// OTEL telemetry: emit pre-request events (tengu_started, tengu_api_query etc.)
|
|
||||||
go telemetry.EmitPreRequest(
|
|
||||||
fmt.Sprintf("%d", account.ID),
|
|
||||||
token,
|
|
||||||
token,
|
|
||||||
reqModel,
|
|
||||||
getHeaderRaw(c.Request.Header, "anthropic-beta"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 调试日志:记录即将转发的账号信息
|
// 调试日志:记录即将转发的账号信息
|
||||||
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||||
account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL)
|
account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL)
|
||||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||||
body = StripEmptyTextBlocks(body)
|
body = StripEmptyTextBlocks(body)
|
||||||
|
|
||||||
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
|
|
||||||
if account.IsContextCompressionEnabled() {
|
|
||||||
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
|
|
||||||
body = compressMessagesInBody(body, maxTok)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 重试循环
|
// 重试循环
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
retryStart := time.Now()
|
retryStart := time.Now()
|
||||||
@ -5032,18 +4948,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 处理正常响应
|
// 处理正常响应
|
||||||
|
|
||||||
// OTEL telemetry: emit post-request events (fire-and-forget)
|
|
||||||
if tokenType == "oauth" && token != "" {
|
|
||||||
go telemetry.EmitPostRequest(
|
|
||||||
fmt.Sprintf("%d", account.ID),
|
|
||||||
token,
|
|
||||||
token,
|
|
||||||
reqModel,
|
|
||||||
getHeaderRaw(c.Request.Header, "anthropic-beta"),
|
|
||||||
resp.StatusCode,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
||||||
if parsed.OnUpstreamAccepted != nil {
|
if parsed.OnUpstreamAccepted != nil {
|
||||||
parsed.OnUpstreamAccepted()
|
parsed.OnUpstreamAccepted()
|
||||||
@ -5350,6 +5254,17 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
|||||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 能力维度 body sanitize:透传路径上 anthropic-beta header 原样透传客户端值,
|
||||||
|
// 依此决定是否保留 body 中的 context_management。避免“客户端 body 带字段但
|
||||||
|
// header 忘记带 beta token”的客户端 bug 在透传场景下让上游 400。
|
||||||
|
clientBeta := ""
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
clientBeta = getHeaderRaw(c.Request.Header, "anthropic-beta")
|
||||||
|
}
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
|
||||||
|
body = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -6175,9 +6090,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
|
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
|
||||||
var fingerprint *Fingerprint
|
var fingerprint *Fingerprint
|
||||||
enableFP, enableMPT, _ := true, false, false
|
enableFP, enableMPT, enableCCH := true, false, false
|
||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
enableFP, enableMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
|
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||||
}
|
}
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && s.identityService != nil {
|
||||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||||
@ -6208,9 +6123,33 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
if fingerprint != nil {
|
if fingerprint != nil {
|
||||||
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
|
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
|
||||||
}
|
}
|
||||||
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)。
|
|
||||||
// 无占位符时函数为 no-op,故无需 enableCCH gate — 占位符存在即意味着必须签名。
|
// === 计算最终 anthropic-beta header(先于 body sanitize 与 CCH 签名)===
|
||||||
body = signBillingHeaderCCH(body)
|
//
|
||||||
|
// 顺序约束:
|
||||||
|
// 1) 算 finalBeta(纯函数,不依赖 req.Header;mimicry 路径会忽略客户端 beta,
|
||||||
|
// 与原“OAuth + mimicClaudeCode 跳过白名单透传”行为对齐)
|
||||||
|
// 2) 按 finalBeta 做能力维度 body sanitize(如 context-management beta 缺失 →
|
||||||
|
// strip body.context_management,与 Bedrock 路径对称)
|
||||||
|
// 3) CCH 签名(必须使用 strip 后的 body,否则 hash 与最终 body 不一致 →
|
||||||
|
// 被 Anthropic 判 third-party)
|
||||||
|
// 4) NewRequest(body 至此最终敲定)
|
||||||
|
// 5) 透传白名单 / fingerprint / mimic header / 写入 finalBeta
|
||||||
|
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
|
||||||
|
effectiveDropSet := mergeDropSets(policyFilterSet)
|
||||||
|
finalBetaHeader, finalBetaShouldSet := s.computeFinalAnthropicBeta(
|
||||||
|
tokenType, mimicClaudeCode, modelID, clientHeaders, body, effectiveDropSet,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 能力维度 body sanitize:与最终 anthropic-beta header 对称
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBetaHeader); changed {
|
||||||
|
body = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
|
||||||
|
if enableCCH {
|
||||||
|
body = signBillingHeaderCCH(body)
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -6257,57 +6196,27 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
applyClaudeOAuthHeaderDefaults(req)
|
applyClaudeOAuthHeaderDefaults(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
// OAuth + mimic Claude Code:强制注入 CLI 指纹相关 header
|
||||||
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
|
// (user-agent/x-stainless-*/x-app/Accept/x-stainless-helper-method/x-client-request-id)
|
||||||
effectiveDropSet := mergeDropSets(policyFilterSet)
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
|
applyClaudeCodeMimicHeaders(req, reqStream)
|
||||||
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
|
||||||
if tokenType == "oauth" {
|
|
||||||
if mimicClaudeCode {
|
|
||||||
// 非 Claude Code 客户端:按 opencode 的策略处理:
|
|
||||||
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
|
|
||||||
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
|
|
||||||
applyClaudeCodeMimicHeaders(req, reqStream)
|
|
||||||
|
|
||||||
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
|
||||||
// Claude Code OAuth credentials are scoped to Claude Code.
|
|
||||||
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
|
|
||||||
// this as a legitimate Claude Code request; without it, the request is
|
|
||||||
// rejected as third-party ("out of extra usage").
|
|
||||||
// Haiku models are exempt from third-party detection and don't need it.
|
|
||||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
|
||||||
if !strings.Contains(strings.ToLower(modelID), "haiku") {
|
|
||||||
requiredBetas = claude.FullClaudeCodeMimicryBetas()
|
|
||||||
}
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
|
|
||||||
} else {
|
|
||||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
|
||||||
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// API-key accounts: apply beta policy filter to strip controlled tokens
|
|
||||||
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
|
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
|
||||||
if requestNeedsBetaFeatures(body) {
|
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", beta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// X-Claude-Code-Session-Id 头处理:
|
// 写入最终 anthropic-beta header
|
||||||
// Claude Code 2.1.145 SDK 内强制设置该头(`"X-Claude-Code-Session-Id":y_()`)。
|
// 注:透传分支白名单可能写入了客户端 anthropic-beta,无条件 Del 一次再按 finalBeta
|
||||||
// 优先取 metadata.user_id 中的 sessionID;OAuth mimic 场景缺失时兜底 UUID,
|
// 决定是否 set,确保 dropSet 过滤后的结果一定覆盖客户端原始值。
|
||||||
// 避免上游基于该头缺失判定为第三方调用。
|
deleteHeaderAllForms(req.Header, "anthropic-beta")
|
||||||
ensureClaudeCodeSessionID(req, body, tokenType, mimicClaudeCode)
|
if finalBetaShouldSet {
|
||||||
|
setHeaderRaw(req.Header, "anthropic-beta", finalBetaHeader)
|
||||||
|
}
|
||||||
|
|
||||||
// x-client-request-id: 真实 CLI 每个请求生成新 UUID(仅 1P)。
|
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||||
|
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||||
|
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
|
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
|
||||||
@ -6320,25 +6229,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
"enable_mpt": strconv.FormatBool(enableMPT),
|
"enable_mpt": strconv.FormatBool(enableMPT),
|
||||||
})
|
})
|
||||||
|
|
||||||
// === Node.js 伪装验证与清理 ===
|
|
||||||
// 对于 OAuth 账号,验证请求是否正确伪装为 Node.js Claude Code 客户端
|
|
||||||
// 这是一个"猴子补丁",确保所有头都符合 Node.js 客户端标准
|
|
||||||
if tokenType == "oauth" && account.IsOAuth() {
|
|
||||||
// 1. 清理任何可能暴露 Go 客户端身份的头
|
|
||||||
if ua := req.Header.Get("User-Agent"); ua == "" || strings.Contains(ua, "Go-http-client") {
|
|
||||||
// User-Agent 缺失或包含 Go 指示,修复为 Node.js 格式
|
|
||||||
setHeaderRaw(req.Header, "User-Agent", claude.DefaultUserAgent())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 验证 Node.js 指纹完整性(用于调试日志)
|
|
||||||
if s.debugClaudeMimicEnabled() {
|
|
||||||
isValid, errors := claudemask.ValidateNodeEmulation(req)
|
|
||||||
if !isValid {
|
|
||||||
logger.LegacyPrintf("service.gateway", "⚠️ Node.js emulation validation failed: %v", errors)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always capture a compact fingerprint line for later error diagnostics.
|
// Always capture a compact fingerprint line for later error diagnostics.
|
||||||
// We only print it when needed (or when the explicit debug flag is enabled).
|
// We only print it when needed (or when the explicit debug flag is enabled).
|
||||||
if c != nil && tokenType == "oauth" {
|
if c != nil && tokenType == "oauth" {
|
||||||
@ -6349,7 +6239,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
|
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
|
||||||
@ -6365,6 +6254,16 @@ func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 能力维度 sanitize:Vertex 路径上 anthropic-beta header 原样透传客户端值
|
||||||
|
// (下面白名单跳过 anthropic-version 但保留 anthropic-beta),依此决定是否
|
||||||
|
// 保留 body 中的 context_management,与 Anthropic 直连 / Bedrock 路径对称。
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
clientBeta := getHeaderRaw(c.Request.Header, "anthropic-beta")
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(vertexBody, clientBeta); changed {
|
||||||
|
vertexBody = sanitized
|
||||||
|
}
|
||||||
|
}
|
||||||
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
|
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -6449,7 +6348,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
|
|||||||
return claude.HaikuBetaHeader
|
return claude.HaikuBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
return claude.GetOAuthBetaHeader(modelID)
|
return claude.DefaultBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestNeedsBetaFeatures(body []byte) bool {
|
func requestNeedsBetaFeatures(body []byte) bool {
|
||||||
@ -6466,7 +6365,10 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
|||||||
|
|
||||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||||
modelID := gjson.GetBytes(body, "model").String()
|
modelID := gjson.GetBytes(body, "model").String()
|
||||||
return claude.GetAPIKeyBetaHeader(modelID)
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
|
return claude.APIKeyHaikuBetaHeader
|
||||||
|
}
|
||||||
|
return claude.APIKeyBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
||||||
@ -6476,7 +6378,7 @@ func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
|||||||
if getHeaderRaw(req.Header, "Accept") == "" {
|
if getHeaderRaw(req.Header, "Accept") == "" {
|
||||||
setHeaderRaw(req.Header, "Accept", "application/json")
|
setHeaderRaw(req.Header, "Accept", "application/json")
|
||||||
}
|
}
|
||||||
for key, value := range claude.DefaultHeadersSnapshot() {
|
for key, value := range claude.DefaultHeaders {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -6530,6 +6432,121 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
|
|||||||
return strings.Join(out, ",")
|
return strings.Join(out, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// computeFinalAnthropicBeta 计算发往上游的最终 anthropic-beta header 值。
|
||||||
|
//
|
||||||
|
// 设计动机:将原本在 buildUpstreamRequest 内联在一起、依赖 req.Header 的
|
||||||
|
// anthropic-beta 计算逻辑抽成纯函数。这样调用方可以在 NewRequest 之前
|
||||||
|
// 就提前拿到最终 beta header,进而能按它对 body 做能力维度 sanitize 后再做
|
||||||
|
// CCH 签名——一举修复了以下之前由顺序依赖导致的能力维度 sanitize
|
||||||
|
// 无法部署的问题(签名与最终 body 不一致可以被判 third-party)。
|
||||||
|
//
|
||||||
|
// 返回 (value, shouldSet):
|
||||||
|
// - shouldSet=false 意为“不主动设置 anthropic-beta header”,与原代码“
|
||||||
|
// API-key 账号 + 客户端未传 anthropic-beta + InjectBetaForAPIKey 未开启或
|
||||||
|
// requestNeedsBetaFeatures=false”的行为对齐。
|
||||||
|
// - shouldSet=true 时 value 可能为空字符串(例如客户端透传的 beta 被 dropSet
|
||||||
|
// 全部过滤掉),这与原代码中 setHeaderRaw 的结果一致。
|
||||||
|
//
|
||||||
|
// clientHeaders 是客户端原始 HTTP header(通常为 c.Request.Header);nil 时按“客户端
|
||||||
|
// 未传”处理。body 是已经 metadata 重写 / billing version sync 之后但未 sanitize 上游
|
||||||
|
// 不兼容字段之前的版本。
|
||||||
|
func (s *GatewayService) computeFinalAnthropicBeta(
|
||||||
|
tokenType string,
|
||||||
|
mimicClaudeCode bool,
|
||||||
|
modelID string,
|
||||||
|
clientHeaders http.Header,
|
||||||
|
body []byte,
|
||||||
|
effectiveDropSet map[string]struct{},
|
||||||
|
) (string, bool) {
|
||||||
|
clientBeta := ""
|
||||||
|
if clientHeaders != nil {
|
||||||
|
clientBeta = getHeaderRaw(clientHeaders, "anthropic-beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenType == "oauth" {
|
||||||
|
if mimicClaudeCode {
|
||||||
|
// mimic 路径:原代码跳过白名单透传,incomingBeta 总是空字符串。
|
||||||
|
// 这里传空 string 以严格对齐原行为。
|
||||||
|
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||||
|
if !strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
|
requiredBetas = claude.FullClaudeCodeMimicryBetas()
|
||||||
|
}
|
||||||
|
return mergeAnthropicBetaDropping(requiredBetas, "", effectiveDropSet), true
|
||||||
|
}
|
||||||
|
// 真 Claude Code 客户端透传路径
|
||||||
|
return stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBeta), effectiveDropSet), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// API-key accounts
|
||||||
|
if clientBeta != "" {
|
||||||
|
return stripBetaTokensWithSet(clientBeta, effectiveDropSet), true
|
||||||
|
}
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||||
|
return beta, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeFinalCountTokensAnthropicBeta 是 count_tokens 路径上 anthropic-beta header 的
|
||||||
|
// 计算纯函数。语义与 computeFinalAnthropicBeta 对齐,但备份了 count_tokens 独有的
|
||||||
|
// 两条特殊规则:
|
||||||
|
//
|
||||||
|
// - OAuth mimic:requiredBetas 为 FullClaudeCodeMimicryBetas + BetaTokenCounting
|
||||||
|
// (与 messages 不同的是:不按 haiku 排除;count_tokens 始终携带 token-counting beta)
|
||||||
|
// - OAuth 透传 + 客户端未传 anthropic-beta:补齐 CountTokensBetaHeader
|
||||||
|
// - OAuth 透传 + 客户端传了:补齐 BetaTokenCounting(如果未含)
|
||||||
|
//
|
||||||
|
// 返回语义同 computeFinalAnthropicBeta。
|
||||||
|
func (s *GatewayService) computeFinalCountTokensAnthropicBeta(
|
||||||
|
tokenType string,
|
||||||
|
mimicClaudeCode bool,
|
||||||
|
modelID string,
|
||||||
|
clientHeaders http.Header,
|
||||||
|
body []byte,
|
||||||
|
effectiveDropSet map[string]struct{},
|
||||||
|
) (string, bool) {
|
||||||
|
clientBeta := ""
|
||||||
|
if clientHeaders != nil {
|
||||||
|
clientBeta = getHeaderRaw(clientHeaders, "anthropic-beta")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenType == "oauth" {
|
||||||
|
if mimicClaudeCode {
|
||||||
|
// 与原代码严格等价:original buildCountTokensRequest 在 count_tokens mimic
|
||||||
|
// 分支上**不**会跳过白名单透传(与 messages mimic 路径不同),所以
|
||||||
|
// incomingBeta = req.Header[anthropic-beta] = 客户端透传过来的 client beta。
|
||||||
|
// 重构后直接从 clientHeaders 拿同一个值,保持行为一致。
|
||||||
|
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
|
||||||
|
return mergeAnthropicBetaDropping(requiredBetas, clientBeta, effectiveDropSet), true
|
||||||
|
}
|
||||||
|
if clientBeta == "" {
|
||||||
|
return claude.CountTokensBetaHeader, true
|
||||||
|
}
|
||||||
|
beta := s.getBetaHeader(modelID, clientBeta)
|
||||||
|
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
||||||
|
beta = beta + "," + claude.BetaTokenCounting
|
||||||
|
}
|
||||||
|
return stripBetaTokensWithSet(beta, effectiveDropSet), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// API-key accounts
|
||||||
|
if clientBeta != "" {
|
||||||
|
return stripBetaTokensWithSet(clientBeta, effectiveDropSet), true
|
||||||
|
}
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
||||||
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||||
|
return beta, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
// stripBetaTokens removes the given beta tokens from a comma-separated header value.
|
// stripBetaTokens removes the given beta tokens from a comma-separated header value.
|
||||||
func stripBetaTokens(header string, tokens []string) string {
|
func stripBetaTokens(header string, tokens []string) string {
|
||||||
if header == "" || len(tokens) == 0 {
|
if header == "" || len(tokens) == 0 {
|
||||||
@ -6810,7 +6827,7 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
|
|||||||
applyClaudeOAuthHeaderDefaults(req)
|
applyClaudeOAuthHeaderDefaults(req)
|
||||||
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
|
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
|
||||||
// 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App")
|
// 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App")
|
||||||
for key, value := range claude.DefaultHeadersSnapshot() {
|
for key, value := range claude.DefaultHeaders {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -9072,9 +9089,6 @@ func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context,
|
|||||||
|
|
||||||
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
|
// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。
|
||||||
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
|
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
|
||||||
if account.Platform == PlatformWindsurf {
|
|
||||||
return windsurf.ResolveModel(requestedModel)
|
|
||||||
}
|
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
return mapAntigravityModel(account, requestedModel)
|
return mapAntigravityModel(account, requestedModel)
|
||||||
}
|
}
|
||||||
@ -9158,7 +9172,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
|
|
||||||
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
|
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
|
||||||
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
||||||
if account.Platform == PlatformAntigravity || account.Platform == PlatformWindsurf {
|
if account.Platform == PlatformAntigravity {
|
||||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -9434,6 +9448,16 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
|||||||
}
|
}
|
||||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||||
}
|
}
|
||||||
|
body = sanitizeCountTokensRequestBody(body)
|
||||||
|
|
||||||
|
// 同 buildUpstreamRequestAnthropicAPIKeyPassthrough:能力维度 sanitize。
|
||||||
|
clientBeta := ""
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
clientBeta = getHeaderRaw(c.Request.Header, "anthropic-beta")
|
||||||
|
}
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, clientBeta); changed {
|
||||||
|
body = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -9501,9 +9525,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
|
|
||||||
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
|
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
|
||||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||||
ctEnableFP, ctEnableMPT, _ := true, false, false
|
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
|
||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
ctEnableFP, ctEnableMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
|
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||||
}
|
}
|
||||||
var ctFingerprint *Fingerprint
|
var ctFingerprint *Fingerprint
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && s.identityService != nil {
|
||||||
@ -9525,8 +9549,23 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if ctFingerprint != nil && ctEnableFP {
|
if ctFingerprint != nil && ctEnableFP {
|
||||||
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
|
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
|
||||||
}
|
}
|
||||||
// 无占位符时函数为 no-op,故无需 ctEnableCCH gate。
|
|
||||||
body = signBillingHeaderCCH(body)
|
// === 计算最终 anthropic-beta header(先于 body sanitize 与 CCH 签名)===
|
||||||
|
// 顺序约束同 buildUpstreamRequest。
|
||||||
|
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
|
||||||
|
finalBetaHeader, finalBetaShouldSet := s.computeFinalCountTokensAnthropicBeta(
|
||||||
|
tokenType, mimicClaudeCode, modelID, clientHeaders, body, ctEffectiveDropSet,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 能力维度 body sanitize:与最终 anthropic-beta header 对称
|
||||||
|
if sanitized, changed := sanitizeAnthropicBodyForBetaTokens(body, finalBetaHeader); changed {
|
||||||
|
body = sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctEnableCCH {
|
||||||
|
body = signBillingHeaderCCH(body)
|
||||||
|
}
|
||||||
|
body = sanitizeCountTokensRequestBody(body)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -9567,50 +9606,24 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
applyClaudeOAuthHeaderDefaults(req)
|
applyClaudeOAuthHeaderDefaults(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
// OAuth + mimic Claude Code:强制注入 CLI 指纹 header
|
||||||
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
|
applyClaudeCodeMimicHeaders(req, false)
|
||||||
// OAuth 账号:处理 anthropic-beta header
|
|
||||||
if tokenType == "oauth" {
|
|
||||||
if mimicClaudeCode {
|
|
||||||
applyClaudeCodeMimicHeaders(req, false)
|
|
||||||
|
|
||||||
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
|
||||||
requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
|
|
||||||
} else {
|
|
||||||
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
|
||||||
if clientBetaHeader == "" {
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader)
|
|
||||||
} else {
|
|
||||||
beta := s.getBetaHeader(modelID, clientBetaHeader)
|
|
||||||
if !strings.Contains(beta, claude.BetaTokenCounting) {
|
|
||||||
beta = beta + "," + claude.BetaTokenCounting
|
|
||||||
}
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// API-key accounts: apply beta policy filter to strip controlled tokens
|
|
||||||
if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet))
|
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
|
|
||||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
|
||||||
if requestNeedsBetaFeatures(body) {
|
|
||||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", beta)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// X-Claude-Code-Session-Id 头处理(count_tokens 路径):
|
// 写入最终 anthropic-beta header(Del 一次避免白名单透传值残留)
|
||||||
// 与 messages 路径保持同样逻辑,OAuth mimic 场景缺失时兜底 UUID。
|
deleteHeaderAllForms(req.Header, "anthropic-beta")
|
||||||
ensureClaudeCodeSessionID(req, body, tokenType, mimicClaudeCode)
|
if finalBetaShouldSet {
|
||||||
|
setHeaderRaw(req.Header, "anthropic-beta", finalBetaHeader)
|
||||||
|
}
|
||||||
|
|
||||||
// x-client-request-id(count_tokens 路径)
|
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||||
|
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||||
|
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c != nil && tokenType == "oauth" {
|
if c != nil && tokenType == "oauth" {
|
||||||
@ -9623,6 +9636,25 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeCountTokensRequestBody(body []byte) []byte {
|
||||||
|
out := body
|
||||||
|
for _, path := range []string{
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"stream",
|
||||||
|
"stop_sequences",
|
||||||
|
"stop",
|
||||||
|
} {
|
||||||
|
if gjson.GetBytes(out, path).Exists() {
|
||||||
|
if next, ok := deleteJSONPathBytes(out, path); ok {
|
||||||
|
out = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// countTokensError 返回 count_tokens 错误响应
|
// countTokensError 返回 count_tokens 错误响应
|
||||||
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
|
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
|
|||||||
@ -2031,6 +2031,22 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
|
|||||||
parts := extractGeminiParts(geminiResp)
|
parts := extractGeminiParts(geminiResp)
|
||||||
for _, part := range parts {
|
for _, part := range parts {
|
||||||
if text, ok := part["text"].(string); ok && text != "" {
|
if text, ok := part["text"].(string); ok && text != "" {
|
||||||
|
// Close an open tool_use block before starting text, mirroring
|
||||||
|
// the functionCall branch (which closes open text blocks) and
|
||||||
|
// the chat-completions sibling's closeOpenTool(). Otherwise a
|
||||||
|
// tool→text sequence keeps the tool_use block open while the
|
||||||
|
// text block starts, emitting overlapping Anthropic content
|
||||||
|
// blocks that violate the SSE contract.
|
||||||
|
if openToolIndex >= 0 {
|
||||||
|
writeSSE(c.Writer, "content_block_stop", map[string]any{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": openToolIndex,
|
||||||
|
})
|
||||||
|
openToolIndex = -1
|
||||||
|
openToolName = ""
|
||||||
|
seenToolJSON = ""
|
||||||
|
}
|
||||||
|
|
||||||
delta, newSeen := computeGeminiTextDelta(seenText, text)
|
delta, newSeen := computeGeminiTextDelta(seenText, text)
|
||||||
seenText = newSeen
|
seenText = newSeen
|
||||||
if delta == "" {
|
if delta == "" {
|
||||||
|
|||||||
@ -832,3 +832,108 @@ func TestParseGeminiRateLimitResetTime(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGeminiMessagesHandleStreamingResponse_ClosesToolBlockBeforeText guards the
|
||||||
|
// tool→text ordering in the Gemini→Anthropic (messages) streaming bridge. When
|
||||||
|
// Gemini emits a functionCall part followed by a text part, the tool_use content
|
||||||
|
// block must be closed before the text block opens; otherwise the Anthropic SSE
|
||||||
|
// stream contains overlapping content blocks. The chat-completions sibling
|
||||||
|
// already enforces this via closeOpenTool().
|
||||||
|
func TestGeminiMessagesHandleStreamingResponse_ClosesToolBlockBeforeText(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
upstreamBody := `data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"SF"}}}]}}]}` + "\n\n" +
|
||||||
|
`data: {"candidates":[{"content":{"parts":[{"text":"All done."}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}` + "\n\n" +
|
||||||
|
"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{}
|
||||||
|
result, err := svc.handleStreamingResponse(c, resp, time.Now(), "claude-3-5-sonnet")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
events := parseAnthropicContentBlockEvents(t, rec.Body.String())
|
||||||
|
|
||||||
|
// Anthropic allows at most one content block open at a time: every
|
||||||
|
// content_block_start must be matched by a content_block_stop before the
|
||||||
|
// next start. Replay the lifecycle and assert there is no overlap.
|
||||||
|
open := -1
|
||||||
|
blockTypes := map[int]string{}
|
||||||
|
textStarted := false
|
||||||
|
toolClosed := false
|
||||||
|
toolClosedBeforeText := false
|
||||||
|
for _, ev := range events {
|
||||||
|
switch ev.event {
|
||||||
|
case "content_block_start":
|
||||||
|
require.Equalf(t, -1, open,
|
||||||
|
"content block %d opened while block %d was still open (overlapping blocks)", ev.index, open)
|
||||||
|
open = ev.index
|
||||||
|
blockTypes[ev.index] = ev.blockType
|
||||||
|
if ev.blockType == "text" {
|
||||||
|
textStarted = true
|
||||||
|
if toolClosed {
|
||||||
|
toolClosedBeforeText = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "content_block_stop":
|
||||||
|
require.Equalf(t, open, ev.index,
|
||||||
|
"content_block_stop index %d does not match the open block %d", ev.index, open)
|
||||||
|
if blockTypes[ev.index] == "tool_use" {
|
||||||
|
toolClosed = true
|
||||||
|
}
|
||||||
|
open = -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, textStarted, "expected a text content block to be emitted after the tool call")
|
||||||
|
require.True(t, toolClosedBeforeText, "tool_use block must be closed before the text block starts")
|
||||||
|
require.Equal(t, -1, open, "stream ended with a content block still open")
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicContentBlockEvent struct {
|
||||||
|
event string
|
||||||
|
index int
|
||||||
|
blockType string
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAnthropicContentBlockEvents extracts content_block_start/stop events (with
|
||||||
|
// their index and, for starts, the content block type) from an Anthropic SSE body.
|
||||||
|
func parseAnthropicContentBlockEvents(t *testing.T, raw string) []anthropicContentBlockEvent {
|
||||||
|
t.Helper()
|
||||||
|
var events []anthropicContentBlockEvent
|
||||||
|
for _, chunk := range strings.Split(raw, "\n\n") {
|
||||||
|
var eventName, dataLine string
|
||||||
|
for _, line := range strings.Split(chunk, "\n") {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(line, "event:"):
|
||||||
|
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||||
|
case strings.HasPrefix(line, "data:"):
|
||||||
|
dataLine = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if eventName != "content_block_start" && eventName != "content_block_stop" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
ContentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"content_block"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
|
||||||
|
events = append(events, anthropicContentBlockEvent{
|
||||||
|
event: eventName,
|
||||||
|
index: payload.Index,
|
||||||
|
blockType: payload.ContentBlock.Type,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|||||||
@ -117,6 +117,20 @@ func addHeaderRaw(h http.Header, key, value string) {
|
|||||||
h[key] = append(h[key], value)
|
h[key] = append(h[key], value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteHeaderAllForms removes a header in all common key forms (raw, wire casing,
|
||||||
|
// canonical) so subsequent setHeaderRaw will not coexist with a passthrough value
|
||||||
|
// written under a different casing.
|
||||||
|
func deleteHeaderAllForms(h http.Header, key string) {
|
||||||
|
if h == nil || key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Del(key) // canonical
|
||||||
|
delete(h, key)
|
||||||
|
if wk := resolveWireCasing(key); wk != key {
|
||||||
|
delete(h, wk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
|
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
|
||||||
// between Go canonical keys, wire casing keys, and raw keys:
|
// between Go canonical keys, wire casing keys, and raw keys:
|
||||||
// 1. exact key as provided
|
// 1. exact key as provided
|
||||||
|
|||||||
@ -44,6 +44,7 @@ type OpenAIAccountScheduleRequest struct {
|
|||||||
PreviousResponseID string
|
PreviousResponseID string
|
||||||
RequestedModel string
|
RequestedModel string
|
||||||
RequiredTransport OpenAIUpstreamTransport
|
RequiredTransport OpenAIUpstreamTransport
|
||||||
|
RequiredCapability OpenAIEndpointCapability
|
||||||
RequiredImageCapability OpenAIImagesCapability
|
RequiredImageCapability OpenAIImagesCapability
|
||||||
RequireCompact bool
|
RequireCompact bool
|
||||||
ExcludedIDs map[int64]struct{}
|
ExcludedIDs map[int64]struct{}
|
||||||
@ -263,12 +264,13 @@ func (s *defaultOpenAIAccountScheduler) Select(
|
|||||||
|
|
||||||
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
|
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
|
||||||
if previousResponseID != "" {
|
if previousResponseID != "" {
|
||||||
selection, err := s.service.SelectAccountByPreviousResponseID(
|
selection, err := s.service.selectAccountByPreviousResponseIDForCapability(
|
||||||
ctx,
|
ctx,
|
||||||
req.GroupID,
|
req.GroupID,
|
||||||
previousResponseID,
|
previousResponseID,
|
||||||
req.RequestedModel,
|
req.RequestedModel,
|
||||||
req.ExcludedIDs,
|
req.ExcludedIDs,
|
||||||
|
req.RequiredCapability,
|
||||||
req.RequireCompact,
|
req.RequireCompact,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -363,12 +365,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
|||||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
|
account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact, req.RequiredCapability)
|
||||||
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if acquireErr == nil && result != nil && result.Acquired {
|
if acquireErr == nil && result != nil && result.Acquired {
|
||||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||||
@ -794,11 +795,11 @@ func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
|
|||||||
compactBlocked := false
|
compactBlocked := false
|
||||||
for i := 0; i < len(selectionOrder); i++ {
|
for i := 0; i < len(selectionOrder); i++ {
|
||||||
candidate := selectionOrder[i]
|
candidate := selectionOrder[i]
|
||||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -934,11 +935,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
cfg := s.service.schedulingConfig()
|
cfg := s.service.schedulingConfig()
|
||||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||||
for _, candidate := range selectionOrder {
|
for _, candidate := range selectionOrder {
|
||||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -977,6 +978,13 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
|
|||||||
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
|
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
// Quota auto-pause must be evaluated during the initial filter too. Without it the
|
||||||
|
// TopK candidate pool can be filled with paused accounts and the later fresh/DB
|
||||||
|
// rechecks won't reach healthy accounts that fell outside TopK — manifesting as
|
||||||
|
// "no available accounts" even though healthy ones exist.
|
||||||
|
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -985,7 +993,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
|
|||||||
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
|
s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
return accountSupportsOpenAICapabilities(account, req.RequiredCapability, req.RequiredImageCapability)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||||
@ -1108,7 +1116,21 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
|||||||
requiredTransport OpenAIUpstreamTransport,
|
requiredTransport OpenAIUpstreamTransport,
|
||||||
requireCompact bool,
|
requireCompact bool,
|
||||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
|
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", "", requireCompact)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForCapability(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
previousResponseID string,
|
||||||
|
sessionHash string,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
requiredTransport OpenAIUpstreamTransport,
|
||||||
|
requiredCapability OpenAIEndpointCapability,
|
||||||
|
requireCompact bool,
|
||||||
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
|
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, requiredCapability, "", requireCompact)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
||||||
@ -1119,13 +1141,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
|||||||
excludedIDs map[int64]struct{},
|
excludedIDs map[int64]struct{},
|
||||||
requiredCapability OpenAIImagesCapability,
|
requiredCapability OpenAIImagesCapability,
|
||||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
|
selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", requiredCapability, false)
|
||||||
if err == nil && selection != nil && selection.Account != nil {
|
if err == nil && selection != nil && selection.Account != nil {
|
||||||
return selection, decision, nil
|
return selection, decision, nil
|
||||||
}
|
}
|
||||||
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
|
// 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
|
||||||
if requiredCapability == OpenAIImagesCapabilityNative {
|
if requiredCapability == OpenAIImagesCapabilityNative {
|
||||||
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
|
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", OpenAIImagesCapabilityBasic, false)
|
||||||
}
|
}
|
||||||
return selection, decision, err
|
return selection, decision, err
|
||||||
}
|
}
|
||||||
@ -1138,9 +1160,11 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
|||||||
requestedModel string,
|
requestedModel string,
|
||||||
excludedIDs map[int64]struct{},
|
excludedIDs map[int64]struct{},
|
||||||
requiredTransport OpenAIUpstreamTransport,
|
requiredTransport OpenAIUpstreamTransport,
|
||||||
|
requiredCapability OpenAIEndpointCapability,
|
||||||
requiredImageCapability OpenAIImagesCapability,
|
requiredImageCapability OpenAIImagesCapability,
|
||||||
requireCompact bool,
|
requireCompact bool,
|
||||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
|
ctx = s.withOpenAIQuotaAutoPauseContext(ctx)
|
||||||
decision := OpenAIAccountScheduleDecision{}
|
decision := OpenAIAccountScheduleDecision{}
|
||||||
scheduler := s.getOpenAIAccountScheduler(ctx)
|
scheduler := s.getOpenAIAccountScheduler(ctx)
|
||||||
if scheduler == nil {
|
if scheduler == nil {
|
||||||
@ -1148,14 +1172,14 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
|||||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||||
for {
|
for {
|
||||||
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
|
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, decision, err
|
return nil, decision, err
|
||||||
}
|
}
|
||||||
if selection == nil || selection.Account == nil {
|
if selection == nil || selection.Account == nil {
|
||||||
return selection, decision, nil
|
return selection, decision, nil
|
||||||
}
|
}
|
||||||
if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
|
if accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) {
|
||||||
return selection, decision, nil
|
return selection, decision, nil
|
||||||
}
|
}
|
||||||
if selection.ReleaseFunc != nil {
|
if selection.ReleaseFunc != nil {
|
||||||
@ -1173,14 +1197,15 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
|||||||
|
|
||||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||||
for {
|
for {
|
||||||
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
|
selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, decision, err
|
return nil, decision, err
|
||||||
}
|
}
|
||||||
if selection == nil || selection.Account == nil {
|
if selection == nil || selection.Account == nil {
|
||||||
return selection, decision, nil
|
return selection, decision, nil
|
||||||
}
|
}
|
||||||
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
|
if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) &&
|
||||||
|
accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) {
|
||||||
return selection, decision, nil
|
return selection, decision, nil
|
||||||
}
|
}
|
||||||
if selection.ReleaseFunc != nil {
|
if selection.ReleaseFunc != nil {
|
||||||
@ -1217,12 +1242,21 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
|||||||
PreviousResponseID: previousResponseID,
|
PreviousResponseID: previousResponseID,
|
||||||
RequestedModel: requestedModel,
|
RequestedModel: requestedModel,
|
||||||
RequiredTransport: requiredTransport,
|
RequiredTransport: requiredTransport,
|
||||||
|
RequiredCapability: requiredCapability,
|
||||||
RequiredImageCapability: requiredImageCapability,
|
RequiredImageCapability: requiredImageCapability,
|
||||||
RequireCompact: requireCompact,
|
RequireCompact: requireCompact,
|
||||||
ExcludedIDs: excludedIDs,
|
ExcludedIDs: excludedIDs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func accountSupportsOpenAICapabilities(account *Account, requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability) bool {
|
||||||
|
if account == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return account.SupportsOpenAIEndpointCapability(requiredCapability) &&
|
||||||
|
account.SupportsOpenAIImageCapability(requiredImageCapability)
|
||||||
|
}
|
||||||
|
|
||||||
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
|
func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
|
||||||
if len(excludedIDs) == 0 {
|
if len(excludedIDs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -417,6 +417,64 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require
|
|||||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) {
|
||||||
|
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(10110)
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 36031,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 36032,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 5,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions", "embeddings"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||||
|
cache: &schedulerTestGatewayCache{},
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
|
||||||
|
ctx,
|
||||||
|
&groupID,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
nil,
|
||||||
|
OpenAIUpstreamTransportHTTPSSE,
|
||||||
|
OpenAIEndpointCapabilityEmbeddings,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.NotNil(t, selection.Account)
|
||||||
|
require.Equal(t, int64(36032), selection.Account.ID)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
|
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
|
||||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||||
|
|
||||||
@ -482,6 +540,141 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev
|
|||||||
require.True(t, decision.StickyPreviousHit)
|
require.True(t, decision.StickyPreviousHit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) {
|
||||||
|
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(10111)
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 37011,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 37012,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 5,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions", "embeddings"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||||
|
cache: &schedulerTestGatewayCache{},
|
||||||
|
cfg: cfg,
|
||||||
|
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||||
|
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
|
||||||
|
ctx,
|
||||||
|
&groupID,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
nil,
|
||||||
|
OpenAIUpstreamTransportHTTPSSE,
|
||||||
|
OpenAIEndpointCapabilityEmbeddings,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.NotNil(t, selection.Account)
|
||||||
|
require.Equal(t, int64(37012), selection.Account.ID)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
require.Equal(t, 1, decision.CandidateCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyStickyBindings(t *testing.T) {
|
||||||
|
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(10112)
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 37021,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 37022,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 5,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions", "embeddings"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := newSchedulerTestOpenAIWSV2Config()
|
||||||
|
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||||
|
cache := &schedulerTestGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{
|
||||||
|
"openai:session_hash_embeddings": 37021,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||||
|
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
store := svc.getOpenAIWSStateStore()
|
||||||
|
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_embeddings_chat_only", 37021, time.Hour))
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithSchedulerForCapability(
|
||||||
|
ctx,
|
||||||
|
&groupID,
|
||||||
|
"resp_embeddings_chat_only",
|
||||||
|
"session_hash_embeddings",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
nil,
|
||||||
|
OpenAIUpstreamTransportHTTPSSE,
|
||||||
|
OpenAIEndpointCapabilityEmbeddings,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.NotNil(t, selection.Account)
|
||||||
|
require.Equal(t, int64(37022), selection.Account.ID)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
require.False(t, decision.StickyPreviousHit)
|
||||||
|
require.False(t, decision.StickySessionHit)
|
||||||
|
require.Equal(t, int64(37022), cache.sessionBindings["openai:session_hash_embeddings"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
|
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
|
||||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||||
|
|
||||||
@ -522,6 +715,224 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
|
|||||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AutoPauseBy5hThreshold(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
primary := Account{
|
||||||
|
ID: 35001,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 95.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35002), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AllowsBelow5hThreshold(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
primary := Account{
|
||||||
|
ID: 35101,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 80.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35101), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_AutoPauseBy7dThreshold(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
primary := Account{
|
||||||
|
ID: 35201,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_7d_used_percent": 95.0,
|
||||||
|
"auto_pause_7d_threshold": 0.95,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35202, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35202), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_UnconfiguredThresholdKeepsLegacyBehavior(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
primary := Account{ID: 35301, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, Extra: map[string]any{"codex_5h_used_percent": 99.0, "codex_7d_used_percent": 99.0}}
|
||||||
|
secondary := Account{ID: 35302, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35301), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_UsesGlobalDefaultThreshold(t *testing.T) {
|
||||||
|
ctx := withOpenAIQuotaAutoPauseSettings(context.Background(), OpsOpenAIAccountQuotaAutoPauseSettings{DefaultThreshold5h: 0.95})
|
||||||
|
primary := Account{
|
||||||
|
ID: 35401,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 95.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35402, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35402), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression: a per-account explicit-disable flag exempts the account from auto-pause
|
||||||
|
// even when a global default threshold is set. Without this, "leave threshold blank"
|
||||||
|
// silently falls back to global default and admins have no way to whitelist a single
|
||||||
|
// account.
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_PerAccountDisableOverridesGlobalDefault(t *testing.T) {
|
||||||
|
ctx := withOpenAIQuotaAutoPauseSettings(context.Background(), OpsOpenAIAccountQuotaAutoPauseSettings{DefaultThreshold5h: 0.95})
|
||||||
|
// Account has high usage AND no per-account threshold (would normally fall back to
|
||||||
|
// the global default and get paused), but the explicit disable flag is set.
|
||||||
|
primary := Account{
|
||||||
|
ID: 35701,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 99.0,
|
||||||
|
"auto_pause_5h_disabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35702, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35701), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable is per-window: disabling only 5h must still allow 7d auto-pause to fire.
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_PerWindowDisableScoped(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
primary := Account{
|
||||||
|
ID: 35801,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 99.0,
|
||||||
|
"codex_7d_used_percent": 99.0,
|
||||||
|
"auto_pause_5h_disabled": true,
|
||||||
|
"auto_pause_7d_threshold": 0.95,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35802), account.ID, "7d auto-pause must still fire even though 5h is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_StaleUsageWindowResetSkipsPause(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
// Usage is over threshold but the window's reset time has already passed, so the
|
||||||
|
// cached percentage is stale (the real window rolled over) and the account must NOT
|
||||||
|
// stay paused — otherwise it could be skipped forever with no traffic to refresh it.
|
||||||
|
primary := Account{
|
||||||
|
ID: 35501,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 99.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
"codex_5h_reset_at": time.Now().Add(-time.Minute).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35502, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35501), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_FreshUsageWindowStillPauses(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
// Same as above but the window has not reset yet, so the account stays paused.
|
||||||
|
primary := Account{
|
||||||
|
ID: 35601,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 99.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
"codex_5h_reset_at": time.Now().Add(time.Hour).Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
secondary := Account{ID: 35602, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||||
|
svc := &OpenAIGatewayService{accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, cfg: &config.Config{}}
|
||||||
|
|
||||||
|
account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.1", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, account)
|
||||||
|
require.Equal(t, int64(35602), account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRateLimitedSnapshotCandidate(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
groupID := int64(10102)
|
groupID := int64(10102)
|
||||||
@ -1069,6 +1480,85 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression: TopK initial filter must drop quota-auto-paused accounts. Otherwise
|
||||||
|
// the candidate pool is filled with paused accounts, healthy accounts fall outside
|
||||||
|
// TopK, and the scheduler returns "no available accounts" even though healthy ones
|
||||||
|
// exist.
|
||||||
|
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKExcludesQuotaPaused(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(110)
|
||||||
|
accounts := []Account{
|
||||||
|
{
|
||||||
|
ID: 37001,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 0,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 96.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 37002,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Priority: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.OpenAIWS.LBTopK = 1 // TopK=1 makes the bug fatal: paused account would crowd out the healthy one entirely
|
||||||
|
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
|
||||||
|
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
|
||||||
|
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
|
||||||
|
|
||||||
|
concurrencyCache := schedulerTestConcurrencyCache{
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
37001: {AccountID: 37001, LoadRate: 5, WaitingCount: 0},
|
||||||
|
37002: {AccountID: 37002, LoadRate: 5, WaitingCount: 0},
|
||||||
|
},
|
||||||
|
acquireResults: map[int64]bool{
|
||||||
|
37002: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||||
|
cache: &schedulerTestGatewayCache{},
|
||||||
|
cfg: cfg,
|
||||||
|
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||||
|
ctx,
|
||||||
|
&groupID,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"gpt-5.1",
|
||||||
|
nil,
|
||||||
|
OpenAIUpstreamTransportAny,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, selection)
|
||||||
|
require.NotNil(t, selection.Account)
|
||||||
|
require.Equal(t, int64(37002), selection.Account.ID)
|
||||||
|
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||||
|
// Only the healthy account should ever enter the candidate pool; the paused one
|
||||||
|
// must be filtered out at the initial-filter stage.
|
||||||
|
require.Equal(t, 1, decision.CandidateCount)
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
groupID := int64(12)
|
groupID := int64(12)
|
||||||
|
|||||||
@ -13,6 +13,10 @@ const (
|
|||||||
CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched"
|
CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched"
|
||||||
// CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。
|
// CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。
|
||||||
CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched"
|
CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched"
|
||||||
|
// CodexClientRestrictionReasonMatchedAllowedClient 表示请求命中账号级额外放行的命名客户端预设。
|
||||||
|
CodexClientRestrictionReasonMatchedAllowedClient = "allowed_client_matched"
|
||||||
|
// CodexClientRestrictionReasonMatchedGlobalAllowedClient 表示请求命中全局额外放行的命名客户端预设。
|
||||||
|
CodexClientRestrictionReasonMatchedGlobalAllowedClient = "global_allowed_client_matched"
|
||||||
// CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。
|
// CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。
|
||||||
CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched"
|
CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched"
|
||||||
// CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。
|
// CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。
|
||||||
@ -28,7 +32,7 @@ type CodexClientRestrictionDetectionResult struct {
|
|||||||
|
|
||||||
// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。
|
// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。
|
||||||
type CodexClientRestrictionDetector interface {
|
type CodexClientRestrictionDetector interface {
|
||||||
Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult
|
Detect(c *gin.Context, account *Account, globalAllowedClients []string) CodexClientRestrictionDetectionResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。
|
// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。
|
||||||
@ -40,7 +44,7 @@ func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexCli
|
|||||||
return &OpenAICodexClientRestrictionDetector{cfg: cfg}
|
return &OpenAICodexClientRestrictionDetector{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account, globalAllowedClients []string) CodexClientRestrictionDetectionResult {
|
||||||
if account == nil || !account.IsCodexCLIOnlyEnabled() {
|
if account == nil || !account.IsCodexCLIOnlyEnabled() {
|
||||||
return CodexClientRestrictionDetectionResult{
|
return CodexClientRestrictionDetectionResult{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
@ -78,6 +82,26 @@ func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 官方客户端白名单未命中时,先尝试账号级额外放行的命名客户端预设(如 Claude Code codex 插件)。
|
||||||
|
if allowed := account.GetCodexCLIOnlyAllowedClients(); len(allowed) > 0 &&
|
||||||
|
openai.MatchAllowedClients(userAgent, originator, allowed) {
|
||||||
|
return CodexClientRestrictionDetectionResult{
|
||||||
|
Enabled: true,
|
||||||
|
Matched: true,
|
||||||
|
Reason: CodexClientRestrictionReasonMatchedAllowedClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 再尝试由更高作用域(全局设置)注入的额外放行客户端列表。
|
||||||
|
if len(globalAllowedClients) > 0 &&
|
||||||
|
openai.MatchAllowedClients(userAgent, originator, globalAllowedClients) {
|
||||||
|
return CodexClientRestrictionDetectionResult{
|
||||||
|
Enabled: true,
|
||||||
|
Matched: true,
|
||||||
|
Reason: CodexClientRestrictionReasonMatchedGlobalAllowedClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return CodexClientRestrictionDetectionResult{
|
return CodexClientRestrictionDetectionResult{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Matched: false,
|
Matched: false,
|
||||||
|
|||||||
@ -30,7 +30,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account)
|
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account, nil)
|
||||||
require.False(t, result.Enabled)
|
require.False(t, result.Enabled)
|
||||||
require.False(t, result.Matched)
|
require.False(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
|
||||||
@ -44,7 +44,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account)
|
result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||||
@ -58,7 +58,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account)
|
result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||||
@ -72,7 +72,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account)
|
result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||||
@ -86,7 +86,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account)
|
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, result.Reason)
|
||||||
@ -100,7 +100,7 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account)
|
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.False(t, result.Matched)
|
require.False(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||||
@ -116,9 +116,146 @@ func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
|||||||
Extra: map[string]any{"codex_cli_only": true},
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account)
|
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAICodexClientRestrictionDetector_Detect_AllowedClients(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
const (
|
||||||
|
claudeCodeUA = "Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)"
|
||||||
|
claudeCodeOriginator = "Claude Code"
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run("配置 claude_code 白名单且命中真实签名时放行", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_cli_only": true,
|
||||||
|
"codex_cli_only_allowed_clients": []any{"claude_code"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.True(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonMatchedAllowedClient, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("配置白名单但伪造 originator 仍拒绝", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_cli_only": true,
|
||||||
|
"codex_cli_only_allowed_clients": []any{"claude_code"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, "my_client"), account, nil)
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.False(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("未配置白名单时 Claude Code 签名仍拒绝", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.False(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("未开启 codex_cli_only 时白名单不参与,直接绕过", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only_allowed_clients": []any{"claude_code"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := detector.Detect(newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), account, nil)
|
||||||
|
require.False(t, result.Enabled)
|
||||||
|
require.False(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("全局列表含 claude_code + 命中签名 → 放行(global)", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
|
}
|
||||||
|
result := detector.Detect(
|
||||||
|
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)", "Claude Code"),
|
||||||
|
account,
|
||||||
|
[]string{"claude_code"},
|
||||||
|
)
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.True(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonMatchedGlobalAllowedClient, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("全局列表含 claude_code + 非签名 → 403", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
|
}
|
||||||
|
result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account, []string{"claude_code"})
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.False(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("全局列表为空 + 账号未配 → 403", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{"codex_cli_only": true},
|
||||||
|
}
|
||||||
|
result := detector.Detect(
|
||||||
|
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos) (Claude Code; 1.0.4)", "Claude Code"),
|
||||||
|
account,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.True(t, result.Enabled)
|
||||||
|
require.False(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("账号白名单优先于全局列表(reason=account)", func(t *testing.T) {
|
||||||
|
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_cli_only": true,
|
||||||
|
"codex_cli_only_allowed_clients": []any{"claude_code"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := detector.Detect(
|
||||||
|
newCodexDetectorTestContext("Claude Code/0.5.0 (Macos) (Claude Code; 1.0.4)", "Claude Code"),
|
||||||
|
account,
|
||||||
|
[]string{"claude_code"},
|
||||||
|
)
|
||||||
|
require.True(t, result.Matched)
|
||||||
|
require.Equal(t, CodexClientRestrictionReasonMatchedAllowedClient, result.Reason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@ -901,7 +901,17 @@ func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMet
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
||||||
return s.getCodexClientRestrictionDetector().Detect(c, account)
|
var globalAllowedClients []string
|
||||||
|
if account != nil && account.IsCodexCLIOnlyEnabled() && s != nil && s.settingService != nil {
|
||||||
|
ctx := context.Background()
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
ctx = c.Request.Context()
|
||||||
|
}
|
||||||
|
if s.settingService.IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx) {
|
||||||
|
globalAllowedClients = []string{openai.AllowedClientClaudeCode}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.getCodexClientRestrictionDetector().Detect(c, account, globalAllowedClients)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAPIKeyIDFromContext(c *gin.Context) int64 {
|
func getAPIKeyIDFromContext(c *gin.Context) int64 {
|
||||||
@ -959,6 +969,7 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
|
|||||||
}
|
}
|
||||||
log := logger.FromContext(ctx).With(fields...)
|
log := logger.FromContext(ctx).With(fields...)
|
||||||
if result.Matched {
|
if result.Matched {
|
||||||
|
log.Info("OpenAI codex_cli_only 放行请求")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
||||||
@ -1279,7 +1290,7 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
|||||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
||||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0)
|
return s.selectAccountForModelWithExclusions(s.withOpenAIQuotaAutoPauseContext(ctx), groupID, sessionHash, requestedModel, excludedIDs, false, 0, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// noAvailableOpenAISelectionError builds the standard "no account available" error
|
// noAvailableOpenAISelectionError builds the standard "no account available" error
|
||||||
@ -1312,19 +1323,228 @@ func openAICompactSupportTier(account *Account) int {
|
|||||||
|
|
||||||
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
|
// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
|
||||||
// compact-support checks used during account selection.
|
// compact-support checks used during account selection.
|
||||||
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool {
|
func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) bool {
|
||||||
if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if paused, reason := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
|
||||||
|
// Debug level: this fires per-candidate on the scheduling hot path, so Info
|
||||||
|
// would amplify into log spam once several accounts cross the threshold.
|
||||||
|
slog.Debug("account_auto_paused_by_quota",
|
||||||
|
"account_id", account.ID,
|
||||||
|
"window", reason.window,
|
||||||
|
"threshold", reason.threshold,
|
||||||
|
"utilization", reason.utilization,
|
||||||
|
)
|
||||||
|
return false
|
||||||
|
}
|
||||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if !account.SupportsOpenAIEndpointCapability(requiredCapability) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if requireCompact && openAICompactSupportTier(account) == 0 {
|
if requireCompact && openAICompactSupportTier(account) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIQuotaAutoPauseDecision struct {
|
||||||
|
window string
|
||||||
|
threshold float64
|
||||||
|
utilization float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldAutoPauseOpenAIAccountByQuota(ctx context.Context, account *Account) (bool, openAIQuotaAutoPauseDecision) {
|
||||||
|
if account == nil || !account.IsOpenAI() {
|
||||||
|
return false, openAIQuotaAutoPauseDecision{}
|
||||||
|
}
|
||||||
|
// Per-account explicit-disable flags must take precedence over the global default.
|
||||||
|
// Without these, leaving the account threshold blank means "use global default",
|
||||||
|
// so an admin has no way to exempt a single account from auto-pause once a global
|
||||||
|
// default exists. The disable flag is per-window so an account can opt out of
|
||||||
|
// only 5h or only 7d auto-pause.
|
||||||
|
disabled5h := resolveAccountExtraBool(account.Extra, "auto_pause_5h_disabled")
|
||||||
|
disabled7d := resolveAccountExtraBool(account.Extra, "auto_pause_7d_disabled")
|
||||||
|
threshold5h, threshold7d := resolveOpenAIQuotaAutoPauseThresholds(ctx, account)
|
||||||
|
now := time.Now()
|
||||||
|
if !disabled5h && threshold5h > 0 {
|
||||||
|
if utilization, ok := resolveOpenAIQuotaUtilization(account.Extra, "5h", now); ok && utilization >= threshold5h {
|
||||||
|
return true, openAIQuotaAutoPauseDecision{window: "5h", threshold: threshold5h, utilization: utilization}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !disabled7d && threshold7d > 0 {
|
||||||
|
if utilization, ok := resolveOpenAIQuotaUtilization(account.Extra, "7d", now); ok && utilization >= threshold7d {
|
||||||
|
return true, openAIQuotaAutoPauseDecision{window: "7d", threshold: threshold7d, utilization: utilization}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, openAIQuotaAutoPauseDecision{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveAccountExtraBool reads a bool-like value from account extra, tolerating
|
||||||
|
// the few shapes JSON unmarshalling may produce (real bool, "true"/"false"
|
||||||
|
// strings, 0/1 numbers).
|
||||||
|
func resolveAccountExtraBool(extra map[string]any, key string) bool {
|
||||||
|
if len(extra) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
value, ok := extra[key]
|
||||||
|
if !ok || value == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
parsed, err := strconv.ParseBool(strings.TrimSpace(v))
|
||||||
|
return err == nil && parsed
|
||||||
|
case float64:
|
||||||
|
return v != 0
|
||||||
|
case float32:
|
||||||
|
return v != 0
|
||||||
|
case int:
|
||||||
|
return v != 0
|
||||||
|
case int64:
|
||||||
|
return v != 0
|
||||||
|
case json.Number:
|
||||||
|
if i, err := v.Int64(); err == nil {
|
||||||
|
return i != 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenAIQuotaAutoPauseThresholds(ctx context.Context, account *Account) (float64, float64) {
|
||||||
|
threshold5h, _ := resolveAccountExtraNumber(account.Extra, "auto_pause_5h_threshold")
|
||||||
|
threshold7d, _ := resolveAccountExtraNumber(account.Extra, "auto_pause_7d_threshold")
|
||||||
|
threshold5h = clamp01(threshold5h)
|
||||||
|
threshold7d = clamp01(threshold7d)
|
||||||
|
if threshold5h > 0 && threshold7d > 0 {
|
||||||
|
return threshold5h, threshold7d
|
||||||
|
}
|
||||||
|
settings := openAIQuotaAutoPauseSettingsFromContext(ctx)
|
||||||
|
if threshold5h <= 0 {
|
||||||
|
threshold5h = clamp01(settings.DefaultThreshold5h)
|
||||||
|
}
|
||||||
|
if threshold7d <= 0 {
|
||||||
|
threshold7d = clamp01(settings.DefaultThreshold7d)
|
||||||
|
}
|
||||||
|
return threshold5h, threshold7d
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveAccountExtraNumber(extra map[string]any, keys ...string) (float64, bool) {
|
||||||
|
if len(extra) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
value, ok := extra[key]
|
||||||
|
if !ok || value == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case float64:
|
||||||
|
return v, true
|
||||||
|
case float32:
|
||||||
|
return float64(v), true
|
||||||
|
case int:
|
||||||
|
return float64(v), true
|
||||||
|
case int64:
|
||||||
|
return float64(v), true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := v.Float64()
|
||||||
|
if err == nil {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
|
||||||
|
if err == nil {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveOpenAIQuotaUtilization returns the current utilization ratio (0..1) for the
|
||||||
|
// given Codex usage window. ok=false means there is no usable signal to pause on:
|
||||||
|
// either no snapshot exists, or the window has already rolled over so the cached
|
||||||
|
// percentage is stale. The stale guard matters because a paused account stops
|
||||||
|
// receiving requests, so its snapshot is never refreshed from upstream headers —
|
||||||
|
// without this check an old used_percent would keep the account paused forever even
|
||||||
|
// after the real window reset.
|
||||||
|
func resolveOpenAIQuotaUtilization(extra map[string]any, window string, now time.Time) (float64, bool) {
|
||||||
|
usedPercent := readOpenAIQuotaUsedPercent(extra, window)
|
||||||
|
if usedPercent <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if openAIQuotaWindowReset(extra, window, now) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return usedPercent / 100, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAIQuotaWindowReset reports whether the Codex usage window's reset time has
|
||||||
|
// already passed relative to now. It prefers the absolute codex_<window>_reset_at
|
||||||
|
// timestamp and falls back to codex_<window>_reset_after_seconds anchored at
|
||||||
|
// codex_usage_updated_at, mirroring AccountUsageService's window-progress logic.
|
||||||
|
func openAIQuotaWindowReset(extra map[string]any, window string, now time.Time) bool {
|
||||||
|
if len(extra) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if resetAtRaw, ok := extra["codex_"+window+"_reset_at"]; ok {
|
||||||
|
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
|
||||||
|
return !now.Before(resetAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resetAfter := parseExtraInt(extra["codex_"+window+"_reset_after_seconds"])
|
||||||
|
if resetAfter <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
base := now
|
||||||
|
if updatedRaw, ok := extra["codex_usage_updated_at"]; ok {
|
||||||
|
if updatedAt, err := parseTime(fmt.Sprint(updatedRaw)); err == nil {
|
||||||
|
base = updatedAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resetAt := base.Add(time.Duration(resetAfter) * time.Second)
|
||||||
|
return !now.Before(resetAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readOpenAIQuotaUsedPercent(extra map[string]any, window string) float64 {
|
||||||
|
if len(extra) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if value, ok := resolveAccountExtraNumber(extra, "codex_"+window+"_used_percent"); ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIQuotaAutoPauseCtxKey struct{}
|
||||||
|
|
||||||
|
func withOpenAIQuotaAutoPauseSettings(ctx context.Context, settings OpsOpenAIAccountQuotaAutoPauseSettings) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, openAIQuotaAutoPauseCtxKey{}, settings)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIQuotaAutoPauseSettingsFromContext(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
|
||||||
|
if ctx == nil {
|
||||||
|
return OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
}
|
||||||
|
settings, _ := ctx.Value(openAIQuotaAutoPauseCtxKey{}).(OpsOpenAIAccountQuotaAutoPauseSettings)
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) withOpenAIQuotaAutoPauseContext(ctx context.Context) context.Context {
|
||||||
|
if s == nil || s.settingService == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return withOpenAIQuotaAutoPauseSettings(ctx, s.settingService.GetOpenAIQuotaAutoPauseSettings(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known
|
// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known
|
||||||
// compact support are tried first, followed by unknown, then explicitly unsupported.
|
// compact support are tried first, followed by unknown, then explicitly unsupported.
|
||||||
// The relative order within each tier is preserved.
|
// The relative order within each tier is preserved.
|
||||||
@ -1366,7 +1586,7 @@ func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedMode
|
|||||||
return upstreamModel
|
return upstreamModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) {
|
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) (*Account, error) {
|
||||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
slog.Warn("channel pricing restriction blocked request",
|
slog.Warn("channel pricing restriction blocked request",
|
||||||
"group_id", derefGroupID(groupID),
|
"group_id", derefGroupID(groupID),
|
||||||
@ -1376,7 +1596,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
|||||||
|
|
||||||
// 1. 尝试粘性会话命中
|
// 1. 尝试粘性会话命中
|
||||||
// Try sticky session hit
|
// Try sticky session hit
|
||||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil {
|
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability); account != nil {
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1389,7 +1609,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
|||||||
|
|
||||||
// 3. 按优先级 + LRU 选择最佳账号
|
// 3. 按优先级 + LRU 选择最佳账号
|
||||||
// Select by priority + LRU
|
// Select by priority + LRU
|
||||||
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)
|
selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact, requiredCapability)
|
||||||
|
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
||||||
@ -1414,7 +1634,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
|||||||
//
|
//
|
||||||
// tryStickySessionHit attempts to get account from sticky session.
|
// tryStickySessionHit attempts to get account from sticky session.
|
||||||
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
||||||
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account {
|
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) *Account {
|
||||||
if sessionHash == "" {
|
if sessionHash == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1446,14 +1666,14 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
|||||||
|
|
||||||
// 验证账号是否可用于当前请求
|
// 验证账号是否可用于当前请求
|
||||||
// Verify account is usable for current request
|
// Verify account is usable for current request
|
||||||
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
|
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.isOpenAIAccountRuntimeBlocked(account) {
|
if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability)
|
||||||
if account == nil {
|
if account == nil {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
return nil
|
return nil
|
||||||
@ -1477,7 +1697,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
|||||||
// Returns nil if no available account. The second return reports whether at
|
// Returns nil if no available account. The second return reports whether at
|
||||||
// least one candidate was filtered out solely because it lacks compact support
|
// least one candidate was filtered out solely because it lacks compact support
|
||||||
// (only meaningful when requireCompact=true).
|
// (only meaningful when requireCompact=true).
|
||||||
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) {
|
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*Account, bool) {
|
||||||
var selected *Account
|
var selected *Account
|
||||||
selectedCompactTier := -1
|
selectedCompactTier := -1
|
||||||
compactBlocked := false
|
compactBlocked := false
|
||||||
@ -1492,11 +1712,11 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false)
|
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1573,10 +1793,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
|
|||||||
|
|
||||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||||
return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false)
|
return s.selectAccountWithLoadAwareness(s.withOpenAIQuotaAutoPauseContext(ctx), groupID, sessionHash, requestedModel, excludedIDs, false, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) {
|
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*AccountSelectionResult, error) {
|
||||||
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
|
||||||
slog.Warn("channel pricing restriction blocked request",
|
slog.Warn("channel pricing restriction blocked request",
|
||||||
"group_id", derefGroupID(groupID),
|
"group_id", derefGroupID(groupID),
|
||||||
@ -1593,7 +1813,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||||
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID)
|
account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1646,8 +1866,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if clearSticky {
|
if clearSticky {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
}
|
}
|
||||||
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) {
|
if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) {
|
||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability)
|
||||||
if account == nil {
|
if account == nil {
|
||||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||||
} else if s.isOpenAIAccountRuntimeBlocked(account) {
|
} else if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||||
@ -1691,15 +1911,12 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
||||||
// re-check schedulability here so recently rate-limited/overloaded accounts
|
// re-check schedulability here so recently rate-limited/overloaded accounts
|
||||||
// are not selected again before the bucket is rebuilt.
|
// are not selected again before the bucket is rebuilt.
|
||||||
if !acc.IsSchedulable() {
|
if !isOpenAIAccountEligibleForRequest(ctx, acc, requestedModel, false, requiredCapability) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if s.isOpenAIAccountRuntimeBlocked(acc) {
|
if s.isOpenAIAccountRuntimeBlocked(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
|
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1779,11 +1996,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range selectionOrder {
|
for _, item := range selectionOrder {
|
||||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1813,11 +2030,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
ordered = prioritizeOpenAICompactAccounts(ordered)
|
ordered = prioritizeOpenAICompactAccounts(ordered)
|
||||||
}
|
}
|
||||||
for _, acc := range ordered {
|
for _, acc := range ordered {
|
||||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1858,11 +2075,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
candidates = prioritizeOpenAICompactAccounts(candidates)
|
candidates = prioritizeOpenAICompactAccounts(candidates)
|
||||||
}
|
}
|
||||||
for _, acc := range candidates {
|
for _, acc := range candidates {
|
||||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
|
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability)
|
||||||
if fresh == nil {
|
if fresh == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1910,7 +2127,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
|
func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1924,7 +2141,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
|||||||
fresh = current
|
fresh = current
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) {
|
if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact, requiredCapability) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
||||||
@ -1933,12 +2150,12 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
|||||||
return fresh
|
return fresh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
|
func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.schedulerSnapshot == nil || s.accountRepo == nil {
|
if s.schedulerSnapshot == nil || s.accountRepo == nil {
|
||||||
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) {
|
if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact, requiredCapability) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return account
|
return account
|
||||||
@ -1948,7 +2165,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
|||||||
if err != nil || latest == nil {
|
if err != nil || latest == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) {
|
if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact, requiredCapability) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||||
@ -4861,7 +5078,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
eventType := gjson.GetBytes(data, "type").String()
|
eventType := gjson.GetBytes(data, "type").String()
|
||||||
if eventType != "response.completed" && eventType != "response.done" &&
|
if eventType != "response.completed" && eventType != "response.done" && eventType != "response.failed" &&
|
||||||
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ type stubCodexRestrictionDetector struct {
|
|||||||
result CodexClientRestrictionDetectionResult
|
result CodexClientRestrictionDetectionResult
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult {
|
func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account, _ []string) CodexClientRestrictionDetectionResult {
|
||||||
return s.result
|
return s.result
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) {
|
|||||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}}
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}}
|
||||||
|
|
||||||
result := got.Detect(c, account)
|
result := got.Detect(c, account, nil)
|
||||||
require.True(t, result.Enabled)
|
require.True(t, result.Enabled)
|
||||||
require.True(t, result.Matched)
|
require.True(t, result.Matched)
|
||||||
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
||||||
|
|||||||
@ -2242,6 +2242,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
|||||||
require.Equal(t, 15, usage.OutputTokens)
|
require.Equal(t, 15, usage.OutputTokens)
|
||||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
// failed 事件在部分上游路径也会携带已消耗 usage,应与 WS/passthrough 保持一致
|
||||||
|
svc.parseSSEUsage(`{"type":"response.failed","response":{"usage":{"input_tokens":17,"output_tokens":19,"input_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||||
|
require.Equal(t, 17, usage.InputTokens)
|
||||||
|
require.Equal(t, 19, usage.OutputTokens)
|
||||||
|
require.Equal(t, 6, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
|
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
|
||||||
require.Equal(t, 21, usage.InputTokens)
|
require.Equal(t, 21, usage.InputTokens)
|
||||||
require.Equal(t, 8, usage.OutputTokens)
|
require.Equal(t, 8, usage.OutputTokens)
|
||||||
|
|||||||
@ -413,6 +413,79 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
|
|||||||
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountSupportsOpenAIEndpointCapability(t *testing.T) {
|
||||||
|
t.Run("OpenAI APIKey 默认兼容 chat 和 embeddings", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OpenAI OAuth 默认仅兼容 chat", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
|
||||||
|
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("显式列表支持同时声明 chat 和 embeddings", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions", "embeddings"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("显式列表只声明 chat 时不支持 embeddings", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
|
||||||
|
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("显式 map 支持单独关闭 chat 并开启 embeddings", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": map[string]any{
|
||||||
|
"chat_completions": false,
|
||||||
|
"embeddings": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions))
|
||||||
|
require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("未知能力不应默认放行", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapability("unknown")))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
|
func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
|
||||||
require.Equal(t,
|
require.Equal(t,
|
||||||
"https://image-upstream.example/v1/images/generations",
|
"https://image-upstream.example/v1/images/generations",
|
||||||
|
|||||||
@ -48,6 +48,46 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_QuotaAutoPausedMiss(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(23)
|
||||||
|
account := Account{
|
||||||
|
ID: 77,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 2,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
"codex_5h_used_percent": 96.0,
|
||||||
|
"auto_pause_5h_threshold": 0.95,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
store := NewOpenAIWSStateStore(cache)
|
||||||
|
cfg := newOpenAIWSV2TestConfig()
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
openaiWSStateStore: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_quota", account.ID, time.Hour))
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_quota", "gpt-5.1", nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, selection, "超过 5h 配额阈值的账号不应继续命中 previous_response_id 粘连")
|
||||||
|
|
||||||
|
// Auto-pause is transient, so the binding is preserved: the chain can resume on the
|
||||||
|
// same account once the quota window resets.
|
||||||
|
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_quota")
|
||||||
|
require.NoError(t, getErr)
|
||||||
|
require.Equal(t, account.ID, boundAccountID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
|
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
groupID := int64(23)
|
groupID := int64(23)
|
||||||
@ -268,6 +308,52 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
|
|||||||
require.Equal(t, int64(21), selection.WaitPlan.AccountID)
|
require.Equal(t, int64(21), selection.WaitPlan.AccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_CapabilityMismatchKeepsSticky(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(25)
|
||||||
|
account := Account{
|
||||||
|
ID: 31,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"openai_capabilities": []any{"chat_completions"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
store := NewOpenAIWSStateStore(cache)
|
||||||
|
cfg := newOpenAIWSV2TestConfig()
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
openaiWSStateStore: store,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_capability", account.ID, time.Hour))
|
||||||
|
|
||||||
|
selection, err := svc.selectAccountByPreviousResponseIDForCapability(
|
||||||
|
ctx,
|
||||||
|
&groupID,
|
||||||
|
"resp_prev_capability",
|
||||||
|
"text-embedding-3-small",
|
||||||
|
nil,
|
||||||
|
OpenAIEndpointCapabilityEmbeddings,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, selection)
|
||||||
|
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_capability")
|
||||||
|
require.NoError(t, getErr)
|
||||||
|
require.Equal(t, account.ID, boundAccountID)
|
||||||
|
}
|
||||||
|
|
||||||
func newOpenAIWSV2TestConfig() *config.Config {
|
func newOpenAIWSV2TestConfig() *config.Config {
|
||||||
cfg := &config.Config{}
|
cfg := &config.Config{}
|
||||||
cfg.Gateway.OpenAIWS.Enabled = true
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
|||||||
@ -369,7 +369,12 @@ func openAIWSEventMayContainToolCalls(eventType string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func openAIWSEventShouldParseUsage(eventType string) bool {
|
func openAIWSEventShouldParseUsage(eventType string) bool {
|
||||||
return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
|
switch strings.TrimSpace(eventType) {
|
||||||
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
|
||||||
@ -2484,6 +2489,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
imageInputSize string
|
imageInputSize string
|
||||||
payloadBytes int
|
payloadBytes int
|
||||||
}
|
}
|
||||||
|
ingressSessionOriginalModel := ""
|
||||||
|
|
||||||
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
|
||||||
next, err := sjson.SetBytes(current, path, value)
|
next, err := sjson.SetBytes(current, path, value)
|
||||||
@ -2547,12 +2553,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
originalModel := strings.TrimSpace(values[1].String())
|
originalModel := strings.TrimSpace(values[1].String())
|
||||||
|
modelMissing := originalModel == ""
|
||||||
if originalModel == "" {
|
if originalModel == "" {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
// 入站 WS 长会话里,部分客户端只在第一轮 response.create 上声明
|
||||||
coderws.StatusPolicyViolation,
|
// model,后续 turn 复用同一 session-level model。为避免因省略
|
||||||
"model is required in response.create payload",
|
// model 直接断开用户连接,这里回落到上一轮已通过校验的客户端模型,
|
||||||
nil,
|
// 并在下方写回上游 payload,保证账号模型映射/fast policy/图片权限
|
||||||
)
|
// 仍按同一模型执行。
|
||||||
|
originalModel = ingressSessionOriginalModel
|
||||||
|
if originalModel == "" {
|
||||||
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||||
|
coderws.StatusPolicyViolation,
|
||||||
|
"model is required in response.create payload",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
promptCacheKey := strings.TrimSpace(values[2].String())
|
promptCacheKey := strings.TrimSpace(values[2].String())
|
||||||
previousResponseID := strings.TrimSpace(values[3].String())
|
previousResponseID := strings.TrimSpace(values[3].String())
|
||||||
@ -2572,7 +2587,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
normalized = next
|
normalized = next
|
||||||
}
|
}
|
||||||
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
|
upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel))
|
||||||
if upstreamModel != originalModel {
|
if modelMissing || upstreamModel != originalModel {
|
||||||
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
||||||
if setErr != nil {
|
if setErr != nil {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||||||
@ -2602,11 +2617,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
// single integration point for all WS ingress turns (first + follow-up
|
// single integration point for all WS ingress turns (first + follow-up
|
||||||
// frames flow through here).
|
// frames flow through here).
|
||||||
//
|
//
|
||||||
// Model fallback: parseClientPayload above rejects any frame whose
|
// Model fallback: first turn still requires model at the handler layer;
|
||||||
// "model" field is missing (line ~2493-2500), so by the time we
|
// follow-up response.create frames may omit it and then reuse
|
||||||
// reach this point upstreamModel is always derived from a non-empty
|
// ingressSessionOriginalModel. We always write a concrete upstream model
|
||||||
// per-frame model. The capturedSessionModel fallback used in the
|
// before evaluating policy, so whitelist / filter behavior remains stable.
|
||||||
// passthrough adapter is therefore not needed in this path.
|
|
||||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||||
@ -2635,6 +2649,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
normalized = policyApplied
|
normalized = policyApplied
|
||||||
|
ingressSessionOriginalModel = originalModel
|
||||||
|
|
||||||
return openAIWSClientPayload{
|
return openAIWSClientPayload{
|
||||||
payloadRaw: normalized,
|
payloadRaw: normalized,
|
||||||
@ -3915,7 +3930,10 @@ func isOpenAIWSTokenEvent(eventType string) bool {
|
|||||||
if strings.HasPrefix(eventType, "response.output") {
|
if strings.HasPrefix(eventType, "response.output") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return eventType == "response.completed" || eventType == "response.done"
|
// 终止事件(response.completed/done/failed/...)由 isOpenAIWSTerminalEvent 单独处理。
|
||||||
|
// 不能把它们当作 token event,否则当上游没有可识别的 delta 时,
|
||||||
|
// firstTokenMs 会被填到终止时刻,等于把"总耗时"误报为"首 token 延迟"。
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
|
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
|
||||||
@ -3987,6 +4005,18 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
|||||||
requestedModel string,
|
requestedModel string,
|
||||||
excludedIDs map[int64]struct{},
|
excludedIDs map[int64]struct{},
|
||||||
requireCompact bool,
|
requireCompact bool,
|
||||||
|
) (*AccountSelectionResult, error) {
|
||||||
|
return s.selectAccountByPreviousResponseIDForCapability(ctx, groupID, previousResponseID, requestedModel, excludedIDs, "", requireCompact)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) selectAccountByPreviousResponseIDForCapability(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
previousResponseID string,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
requiredCapability OpenAIEndpointCapability,
|
||||||
|
requireCompact bool,
|
||||||
) (*AccountSelectionResult, error) {
|
) (*AccountSelectionResult, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -4027,12 +4057,41 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
|
|||||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
if !account.SupportsOpenAIEndpointCapability(requiredCapability) {
|
||||||
if account == nil {
|
|
||||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
// 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
|
// Quota auto-pause must also gate the previous_response_id sticky path; otherwise an
|
||||||
|
// account over its 5h/7d threshold keeps serving the same response chain even though
|
||||||
|
// normal scheduling skips it. Pause is transient, so fall through to normal scheduling
|
||||||
|
// without deleting the binding (the window may reset before the next turn).
|
||||||
|
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, account); paused {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if s.schedulerSnapshot != nil && s.accountRepo != nil {
|
||||||
|
latest, latestErr := s.accountRepo.GetByID(ctx, account.ID)
|
||||||
|
if latestErr != nil || latest == nil {
|
||||||
|
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if shouldClearStickySession(latest, requestedModel) || !latest.IsOpenAI() || !latest.IsSchedulable() {
|
||||||
|
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if !latest.SupportsOpenAIEndpointCapability(requiredCapability) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if paused, _ := shouldAutoPauseOpenAIAccountByQuota(ctx, latest); paused {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||||
|
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
account = latest
|
||||||
|
}
|
||||||
if requireCompact && openAICompactSupportTier(account) == 0 {
|
if requireCompact && openAICompactSupportTier(account) == 0 {
|
||||||
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@ -39,6 +39,24 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
|
|||||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIWSEventShouldParseUsageTerminalEvents(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, eventType := range []string{
|
||||||
|
"response.completed",
|
||||||
|
"response.done",
|
||||||
|
"response.failed",
|
||||||
|
"response.incomplete",
|
||||||
|
"response.cancelled",
|
||||||
|
"response.canceled",
|
||||||
|
} {
|
||||||
|
require.True(t, openAIWSEventShouldParseUsage(eventType), eventType)
|
||||||
|
require.True(t, openAIWSEventShouldParseUsage(" "+eventType+" "), eventType)
|
||||||
|
}
|
||||||
|
require.False(t, openAIWSEventShouldParseUsage("response.output_text.delta"))
|
||||||
|
require.False(t, openAIWSEventShouldParseUsage(""))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
||||||
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
||||||
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||||
|
|||||||
@ -164,6 +164,140 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossT
|
|||||||
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
|
require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_FollowupCreateCanOmitModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
captureConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_omit_model_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(captureDialer)
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 115,
|
||||||
|
Name: "openai-ingress-omit-model",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"client-model": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"client-model","stream":false}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, firstEvent, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_omit_model_1", gjson.GetBytes(firstEvent, "response.id").String())
|
||||||
|
|
||||||
|
writeCtx, cancelWrite = context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","stream":false,"previous_response_id":"resp_omit_model_1"}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead = context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, secondEvent, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_omit_model_2", gjson.GetBytes(secondEvent, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 ingress websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Len(t, captureConn.writes, 2)
|
||||||
|
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[0]), "model").String())
|
||||||
|
require.Equal(t, "gpt-5.1", gjson.Get(requestToJSONString(captureConn.writes[1]), "model").String())
|
||||||
|
require.Equal(t, "resp_omit_model_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@ -441,6 +575,124 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
|
|||||||
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
require.Len(t, upstreamConn.writes, 1, "passthrough 模式应透传首条 response.create")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughHeadersUsePromptCacheAndTurnState(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||||
|
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||||
|
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||||
|
|
||||||
|
upstreamConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_passthrough_headers","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: upstreamConn}
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPassthroughDialer: captureDialer,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 453,
|
||||||
|
Name: "openai-ingress-passthrough-headers",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
req := r.Clone(r.Context())
|
||||||
|
req.Header = req.Header.Clone()
|
||||||
|
req.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
|
||||||
|
req.Header.Set(openAIWSTurnStateHeader, "turn-state-1")
|
||||||
|
req.Header.Set(openAIWSTurnMetadataHeader, "turn-meta-1")
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||||
|
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if readErr != nil {
|
||||||
|
serverErrCh <- readErr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "oauth-token", firstMessage, nil)
|
||||||
|
}))
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"prompt_cache_key":"pcache_passthrough"}`))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, event, readErr := clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.NoError(t, readErr)
|
||||||
|
require.Equal(t, "resp_passthrough_headers", gjson.GetBytes(event, "response.id").String())
|
||||||
|
_ = clientConn.Close(coderws.StatusNormalClosure, "done")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case serverErr := <-serverErrCh:
|
||||||
|
if serverErr != nil {
|
||||||
|
require.Contains(t, serverErr.Error(), "StatusNormalClosure")
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("等待 passthrough websocket 结束超时")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, isolateOpenAISessionID(0, "pcache_passthrough"), captureDialer.lastHeaders.Get("session_id"))
|
||||||
|
require.Equal(t, "turn-state-1", captureDialer.lastHeaders.Get(openAIWSTurnStateHeader))
|
||||||
|
require.Equal(t, "turn-meta-1", captureDialer.lastHeaders.Get(openAIWSTurnMetadataHeader))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@ -727,6 +727,70 @@ func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheK
|
|||||||
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_Forward_WSv2_ResponseDoneUsageParsed(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||||
|
c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||||
|
|
||||||
|
captureConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.done","response":{"id":"resp_done_usage","model":"gpt-5.1","usage":{"input_tokens":13,"output_tokens":8,"input_tokens_details":{"cached_tokens":5},"cache_creation_input_tokens":2,"output_tokens_details":{"image_tokens":4}}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(captureDialer)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 32,
|
||||||
|
Name: "openai-ws-done",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-test",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hi"}]}`)
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, "resp_done_usage", result.RequestID)
|
||||||
|
require.Equal(t, 13, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||||
|
require.Equal(t, 5, result.Usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 2, result.Usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, 4, result.Usage.ImageOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
75
backend/internal/service/openai_ws_forwarder_test.go
Normal file
75
backend/internal/service/openai_ws_forwarder_test.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestIsOpenAIWSTokenEvent_TerminalEventsExcluded 覆盖 isOpenAIWSTokenEvent 的回归用例。
|
||||||
|
// 重点验证终止事件(response.completed / response.done)不再被当作 token event,
|
||||||
|
// 否则当上游没有可识别的 delta 时,firstTokenMs 会被填到终止时刻,
|
||||||
|
// 等于把"总耗时"误报为"首 token 延迟"(issue #2651)。
|
||||||
|
func TestIsOpenAIWSTokenEvent_TerminalEventsExcluded(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
eventType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "empty", eventType: "", want: false},
|
||||||
|
{name: "whitespace_trimmed_empty", eventType: " ", want: false},
|
||||||
|
|
||||||
|
{name: "response.created", eventType: "response.created", want: false},
|
||||||
|
{name: "response.in_progress", eventType: "response.in_progress", want: false},
|
||||||
|
{name: "response.output_item.added", eventType: "response.output_item.added", want: false},
|
||||||
|
{name: "response.output_item.done", eventType: "response.output_item.done", want: false},
|
||||||
|
|
||||||
|
{name: "terminal_response.completed", eventType: "response.completed", want: false},
|
||||||
|
{name: "terminal_response.done", eventType: "response.done", want: false},
|
||||||
|
{name: "terminal_response.completed_padded", eventType: " response.completed ", want: false},
|
||||||
|
{name: "terminal_response.done_padded", eventType: " response.done ", want: false},
|
||||||
|
|
||||||
|
{name: "delta_text", eventType: "response.output_text.delta", want: true},
|
||||||
|
{name: "delta_audio_transcript", eventType: "response.audio_transcript.delta", want: true},
|
||||||
|
{name: "delta_function_call_arguments", eventType: "response.function_call_arguments.delta", want: true},
|
||||||
|
|
||||||
|
{name: "output_text_done", eventType: "response.output_text.done", want: true},
|
||||||
|
{name: "output_text_annotation_added", eventType: "response.output_text.annotation.added", want: true},
|
||||||
|
|
||||||
|
{name: "output_audio_done", eventType: "response.output_audio.done", want: true},
|
||||||
|
|
||||||
|
{name: "reasoning_summary_delta", eventType: "response.reasoning_summary_text.delta", want: true},
|
||||||
|
|
||||||
|
{name: "unrelated_event_error", eventType: "error", want: false},
|
||||||
|
{name: "unknown_event_without_match", eventType: "response.reasoning_summary_part.added", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := isOpenAIWSTokenEvent(tc.eventType)
|
||||||
|
require.Equal(t, tc.want, got, "isOpenAIWSTokenEvent(%q)", tc.eventType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsOpenAIWSTokenEvent_DisjointWithTerminal 守护「token 事件集合与终止事件集合互斥」的不变量。
|
||||||
|
// firstTokenMs 的计算依赖于 isTokenEvent && !isTerminalEvent;
|
||||||
|
// 若两者再次出现交集,则 issue #2651 描述的 latency 误报会重现。
|
||||||
|
func TestIsOpenAIWSTokenEvent_DisjointWithTerminal(t *testing.T) {
|
||||||
|
terminalEvents := []string{
|
||||||
|
"response.completed",
|
||||||
|
"response.done",
|
||||||
|
"response.failed",
|
||||||
|
"response.incomplete",
|
||||||
|
"response.cancelled",
|
||||||
|
"response.canceled",
|
||||||
|
}
|
||||||
|
for _, ev := range terminalEvents {
|
||||||
|
ev := ev
|
||||||
|
t.Run(ev, func(t *testing.T) {
|
||||||
|
require.True(t, isOpenAIWSTerminalEvent(ev), "expected terminal event %q to be classified as terminal", ev)
|
||||||
|
require.False(t, isOpenAIWSTokenEvent(ev), "terminal event %q must NOT be classified as token event (issue #2651)", ev)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -25,6 +25,7 @@ type Usage struct {
|
|||||||
OutputTokens int
|
OutputTokens int
|
||||||
CacheCreationInputTokens int
|
CacheCreationInputTokens int
|
||||||
CacheReadInputTokens int
|
CacheReadInputTokens int
|
||||||
|
ImageOutputTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelayResult struct {
|
type RelayResult struct {
|
||||||
@ -756,8 +757,21 @@ func parseUsageAndAccumulate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
||||||
|
if !inputResult.Exists() {
|
||||||
|
inputResult = gjson.GetBytes(message, "response.usage.prompt_tokens")
|
||||||
|
}
|
||||||
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
||||||
|
if !outputResult.Exists() {
|
||||||
|
outputResult = gjson.GetBytes(message, "response.usage.completion_tokens")
|
||||||
|
}
|
||||||
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
||||||
|
if !cachedResult.Exists() {
|
||||||
|
cachedResult = gjson.GetBytes(message, "response.usage.prompt_tokens_details.cached_tokens")
|
||||||
|
}
|
||||||
|
imageTokens := usageResult.Get("output_tokens_details.image_tokens").Int()
|
||||||
|
if imageTokens == 0 {
|
||||||
|
imageTokens = usageResult.Get("completion_tokens_details.image_tokens").Int()
|
||||||
|
}
|
||||||
|
|
||||||
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
||||||
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
||||||
@ -771,14 +785,18 @@ func parseUsageAndAccumulate(
|
|||||||
return Usage{}
|
return Usage{}
|
||||||
}
|
}
|
||||||
parsedUsage := Usage{
|
parsedUsage := Usage{
|
||||||
InputTokens: inputTokens,
|
InputTokens: inputTokens,
|
||||||
OutputTokens: outputTokens,
|
OutputTokens: outputTokens,
|
||||||
CacheReadInputTokens: cachedTokens,
|
CacheCreationInputTokens: int(usageResult.Get("cache_creation_input_tokens").Int()),
|
||||||
|
CacheReadInputTokens: cachedTokens,
|
||||||
|
ImageOutputTokens: int(imageTokens),
|
||||||
}
|
}
|
||||||
|
|
||||||
state.usage.InputTokens += parsedUsage.InputTokens
|
state.usage.InputTokens += parsedUsage.InputTokens
|
||||||
state.usage.OutputTokens += parsedUsage.OutputTokens
|
state.usage.OutputTokens += parsedUsage.OutputTokens
|
||||||
|
state.usage.CacheCreationInputTokens += parsedUsage.CacheCreationInputTokens
|
||||||
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
||||||
|
state.usage.ImageOutputTokens += parsedUsage.ImageOutputTokens
|
||||||
return parsedUsage
|
return parsedUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -840,7 +858,7 @@ func isTerminalEvent(eventType string) bool {
|
|||||||
|
|
||||||
func shouldParseUsage(eventType string) bool {
|
func shouldParseUsage(eventType string) bool {
|
||||||
switch eventType {
|
switch eventType {
|
||||||
case "response.completed", "response.done", "response.failed":
|
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -300,20 +300,41 @@ func TestParseUsageAndEnrichCoverage(t *testing.T) {
|
|||||||
require.Equal(t, 0, state.usage.OutputTokens)
|
require.Equal(t, 0, state.usage.OutputTokens)
|
||||||
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
require.Equal(t, 0, state.usage.CacheReadInputTokens)
|
||||||
|
|
||||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1},"cache_creation_input_tokens":4,"output_tokens_details":{"image_tokens":3}}}}`), "response.completed", nil)
|
||||||
require.Equal(t, 2, state.usage.InputTokens)
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
require.Equal(t, 1, state.usage.OutputTokens)
|
require.Equal(t, 1, state.usage.OutputTokens)
|
||||||
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
require.Equal(t, 1, state.usage.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 4, state.usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, 3, state.usage.ImageOutputTokens)
|
||||||
|
|
||||||
result := &RelayResult{}
|
result := &RelayResult{}
|
||||||
enrichResult(result, state, 5*time.Millisecond)
|
enrichResult(result, state, 5*time.Millisecond)
|
||||||
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
|
||||||
|
require.Equal(t, state.usage.CacheCreationInputTokens, result.Usage.CacheCreationInputTokens)
|
||||||
|
require.Equal(t, state.usage.ImageOutputTokens, result.Usage.ImageOutputTokens)
|
||||||
require.Equal(t, 5*time.Millisecond, result.Duration)
|
require.Equal(t, 5*time.Millisecond, result.Duration)
|
||||||
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
|
||||||
require.Equal(t, 2, state.usage.InputTokens)
|
require.Equal(t, 2, state.usage.InputTokens)
|
||||||
enrichResult(nil, state, 0)
|
enrichResult(nil, state, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseUsageAndAccumulateAcceptsChatUsageAliases(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
state := &relayState{}
|
||||||
|
got := parseUsageAndAccumulate(
|
||||||
|
state,
|
||||||
|
[]byte(`{"type":"response.done","response":{"usage":{"prompt_tokens":12,"completion_tokens":6,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"image_tokens":2}}}}`),
|
||||||
|
"response.done",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.Equal(t, 12, got.InputTokens)
|
||||||
|
require.Equal(t, 6, got.OutputTokens)
|
||||||
|
require.Equal(t, 4, got.CacheReadInputTokens)
|
||||||
|
require.Equal(t, 2, got.ImageOutputTokens)
|
||||||
|
require.Equal(t, got, state.usage)
|
||||||
|
}
|
||||||
|
|
||||||
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
func TestEmitTurnCompleteCoverage(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -377,6 +398,23 @@ func TestIsTokenEventCoverageBranches(t *testing.T) {
|
|||||||
require.True(t, isTokenEvent("response.done"))
|
require.True(t, isTokenEvent("response.done"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldParseUsageTerminalEvents(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, eventType := range []string{
|
||||||
|
"response.completed",
|
||||||
|
"response.done",
|
||||||
|
"response.failed",
|
||||||
|
"response.incomplete",
|
||||||
|
"response.cancelled",
|
||||||
|
"response.canceled",
|
||||||
|
} {
|
||||||
|
require.True(t, shouldParseUsage(eventType), eventType)
|
||||||
|
}
|
||||||
|
require.False(t, shouldParseUsage("response.output_text.delta"))
|
||||||
|
require.False(t, shouldParseUsage(""))
|
||||||
|
}
|
||||||
|
|
||||||
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@ -312,6 +312,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||||
// goroutine)之间同步当前 turn 的 usage metadata。
|
// goroutine)之间同步当前 turn 的 usage metadata。
|
||||||
usageMeta.initFromFirstFrame(firstClientMessage)
|
usageMeta.initFromFirstFrame(firstClientMessage)
|
||||||
|
promptCacheKey := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "prompt_cache_key").String())
|
||||||
|
|
||||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -338,7 +339,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||||
isCodexCLI = true
|
isCodexCLI = true
|
||||||
}
|
}
|
||||||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
turnState := ""
|
||||||
|
turnMetadata := ""
|
||||||
|
if c != nil {
|
||||||
|
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
|
||||||
|
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
|
||||||
|
}
|
||||||
|
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
@ -519,6 +526,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
OutputTokens: turn.Usage.OutputTokens,
|
OutputTokens: turn.Usage.OutputTokens,
|
||||||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||||
|
ImageOutputTokens: turn.Usage.ImageOutputTokens,
|
||||||
},
|
},
|
||||||
Model: turn.RequestModel,
|
Model: turn.RequestModel,
|
||||||
ServiceTier: usageMeta.serviceTier.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
@ -593,6 +601,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
OutputTokens: relayResult.Usage.OutputTokens,
|
OutputTokens: relayResult.Usage.OutputTokens,
|
||||||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||||
|
ImageOutputTokens: relayResult.Usage.ImageOutputTokens,
|
||||||
},
|
},
|
||||||
Model: relayResult.RequestModel,
|
Model: relayResult.RequestModel,
|
||||||
ServiceTier: usageMeta.serviceTier.Load(),
|
ServiceTier: usageMeta.serviceTier.Load(),
|
||||||
|
|||||||
@ -41,6 +41,11 @@ type OpsService struct {
|
|||||||
// cleanupReloader 由 wire 在 OpsCleanupService 构造完成后通过 SetCleanupReloader 注入。
|
// cleanupReloader 由 wire 在 OpsCleanupService 构造完成后通过 SetCleanupReloader 注入。
|
||||||
// 解耦避免 OpsService -> OpsCleanupService 的硬依赖(cleanup 也读 settings,会循环)。
|
// 解耦避免 OpsService -> OpsCleanupService 的硬依赖(cleanup 也读 settings,会循环)。
|
||||||
cleanupReloader CleanupReloader
|
cleanupReloader CleanupReloader
|
||||||
|
|
||||||
|
// quotaAutoPauseSink 由 wire 注入(通常是 SettingService.SetOpenAIQuotaAutoPauseSettings)。
|
||||||
|
// UpdateOpsAdvancedSettings 写入新配置后调用,把最新的 quota auto-pause 全局默认阈值
|
||||||
|
// 立即同步到调度热路径读取的内存缓存,避免下次请求才能感知新值。
|
||||||
|
quotaAutoPauseSink func(OpsOpenAIAccountQuotaAutoPauseSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupReloader 由 OpsCleanupService 实现。
|
// CleanupReloader 由 OpsCleanupService 实现。
|
||||||
@ -57,6 +62,16 @@ func (s *OpsService) SetCleanupReloader(r CleanupReloader) {
|
|||||||
s.cleanupReloader = r
|
s.cleanupReloader = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOpenAIQuotaAutoPauseSettingsSink 由 wire 注入,把最新的 quota auto-pause 全局默认
|
||||||
|
// 阈值 push 到调度热路径读取的内存缓存。同 SetCleanupReloader 的解耦目的:避免 OpsService
|
||||||
|
// 持有 *SettingService 引入循环依赖。
|
||||||
|
func (s *OpsService) SetOpenAIQuotaAutoPauseSettingsSink(sink func(OpsOpenAIAccountQuotaAutoPauseSettings)) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.quotaAutoPauseSink = sink
|
||||||
|
}
|
||||||
|
|
||||||
func NewOpsService(
|
func NewOpsService(
|
||||||
opsRepo OpsRepository,
|
opsRepo OpsRepository,
|
||||||
settingRepo SettingRepository,
|
settingRepo SettingRepository,
|
||||||
|
|||||||
@ -369,6 +369,7 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
|||||||
Aggregation: OpsAggregationSettings{
|
Aggregation: OpsAggregationSettings{
|
||||||
AggregationEnabled: false,
|
AggregationEnabled: false,
|
||||||
},
|
},
|
||||||
|
OpenAIAccountQuotaAutoPause: OpsOpenAIAccountQuotaAutoPauseSettings{},
|
||||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||||
@ -384,6 +385,8 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
|
|||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold5h = clampOpsQuotaAutoPauseThreshold(cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold5h)
|
||||||
|
cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold7d = clampOpsQuotaAutoPauseThreshold(cfg.OpenAIAccountQuotaAutoPause.DefaultThreshold7d)
|
||||||
cfg.DataRetention.CleanupSchedule = strings.TrimSpace(cfg.DataRetention.CleanupSchedule)
|
cfg.DataRetention.CleanupSchedule = strings.TrimSpace(cfg.DataRetention.CleanupSchedule)
|
||||||
if cfg.DataRetention.CleanupSchedule == "" {
|
if cfg.DataRetention.CleanupSchedule == "" {
|
||||||
cfg.DataRetention.CleanupSchedule = opsCleanupDefaultSchedule
|
cfg.DataRetention.CleanupSchedule = opsCleanupDefaultSchedule
|
||||||
@ -405,6 +408,16 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func clampOpsQuotaAutoPauseThreshold(value float64) float64 {
|
||||||
|
if value <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if value > 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
|
func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return errors.New("invalid config")
|
return errors.New("invalid config")
|
||||||
@ -477,6 +490,12 @@ func (s *OpsService) UpdateOpsAdvancedSettings(ctx context.Context, cfg *OpsAdva
|
|||||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(raw)); err != nil {
|
if err := s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(raw)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Push the new quota auto-pause settings straight into the in-memory cache that
|
||||||
|
// the OpenAI scheduling hot path reads, so the next request observes the new value
|
||||||
|
// without waiting for the background refresher's TTL.
|
||||||
|
if s.quotaAutoPauseSink != nil {
|
||||||
|
s.quotaAutoPauseSink(cfg.OpenAIAccountQuotaAutoPause)
|
||||||
|
}
|
||||||
|
|
||||||
// notify cleanup service to reload schedule/enabled.
|
// notify cleanup service to reload schedule/enabled.
|
||||||
if s.cleanupReloader != nil {
|
if s.cleanupReloader != nil {
|
||||||
|
|||||||
@ -4,6 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
|
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
|
||||||
@ -95,3 +98,64 @@ func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.
|
|||||||
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
|
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOpenAIQuotaAutoPauseSettings_ReadsDefaultsFromOpsAdvancedSettings(t *testing.T) {
|
||||||
|
repo := newRuntimeSettingRepoStub()
|
||||||
|
repo.values[SettingKeyOpsAdvancedSettings] = `{"openai_account_quota_auto_pause":{"default_threshold_5h":0.95,"default_threshold_7d":0.9}}`
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
// Warm the in-memory cache synchronously so the assertion below is deterministic.
|
||||||
|
// GetOpenAIQuotaAutoPauseSettings is non-blocking on the hot path (returns the
|
||||||
|
// cached value, refreshes asynchronously); for tests and startup, Warm is the
|
||||||
|
// synchronous entry point that guarantees a populated cache.
|
||||||
|
settings := svc.WarmOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
if settings.DefaultThreshold5h != 0.95 {
|
||||||
|
t.Fatalf("DefaultThreshold5h = %v, want 0.95", settings.DefaultThreshold5h)
|
||||||
|
}
|
||||||
|
if settings.DefaultThreshold7d != 0.9 {
|
||||||
|
t.Fatalf("DefaultThreshold7d = %v, want 0.9", settings.DefaultThreshold7d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent Get must hit the warm cache and return the same value without any DB
|
||||||
|
// access — that's the hot-path invariant.
|
||||||
|
cached := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
if cached.DefaultThreshold5h != 0.95 || cached.DefaultThreshold7d != 0.9 {
|
||||||
|
t.Fatalf("cached read = %+v, want {0.95, 0.9}", cached)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hot-path invariant: a Get with cold cache must return immediately (zero defaults)
|
||||||
|
// rather than blocking on the DB. The async refresher will populate the cache for
|
||||||
|
// subsequent calls.
|
||||||
|
func TestGetOpenAIQuotaAutoPauseSettings_ColdCacheNonBlocking(t *testing.T) {
|
||||||
|
repo := newRuntimeSettingRepoStub()
|
||||||
|
repo.values[SettingKeyOpsAdvancedSettings] = `{"openai_account_quota_auto_pause":{"default_threshold_5h":0.7}}`
|
||||||
|
svc := NewSettingService(repo, &config.Config{})
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
settings := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
if elapsed > 50*time.Millisecond {
|
||||||
|
t.Fatalf("cold-cache Get must be non-blocking, took %v", elapsed)
|
||||||
|
}
|
||||||
|
// Cold cache means we get zero defaults (the async refresh hasn't completed yet).
|
||||||
|
if settings.DefaultThreshold5h != 0 || settings.DefaultThreshold7d != 0 {
|
||||||
|
t.Fatalf("cold-cache Get = %+v, want zeroes", settings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit cache write (e.g. from UpdateOpsAdvancedSettings) must be visible on the
|
||||||
|
// very next read without any DB roundtrip.
|
||||||
|
func TestSetOpenAIQuotaAutoPauseSettings_VisibleImmediately(t *testing.T) {
|
||||||
|
svc := NewSettingService(newRuntimeSettingRepoStub(), &config.Config{})
|
||||||
|
|
||||||
|
svc.SetOpenAIQuotaAutoPauseSettings(OpsOpenAIAccountQuotaAutoPauseSettings{
|
||||||
|
DefaultThreshold5h: 0.88,
|
||||||
|
DefaultThreshold7d: 0.77,
|
||||||
|
})
|
||||||
|
|
||||||
|
got := svc.GetOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
if got.DefaultThreshold5h != 0.88 || got.DefaultThreshold7d != 0.77 {
|
||||||
|
t.Fatalf("after Set, Get = %+v, want {0.88, 0.77}", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -92,17 +92,23 @@ type OpsAlertRuntimeSettings struct {
|
|||||||
|
|
||||||
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
||||||
type OpsAdvancedSettings struct {
|
type OpsAdvancedSettings struct {
|
||||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
OpenAIAccountQuotaAutoPause OpsOpenAIAccountQuotaAutoPauseSettings `json:"openai_account_quota_auto_pause"`
|
||||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||||
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
|
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
|
||||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||||
|
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpsOpenAIAccountQuotaAutoPauseSettings struct {
|
||||||
|
DefaultThreshold5h float64 `json:"default_threshold_5h"`
|
||||||
|
DefaultThreshold7d float64 `json:"default_threshold_7d"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpsDataRetentionSettings struct {
|
type OpsDataRetentionSettings struct {
|
||||||
|
|||||||
@ -137,10 +137,32 @@ type cachedOpenAICodexUserAgent struct {
|
|||||||
expiresAt int64 // unix nano
|
expiresAt int64 // unix nano
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cachedOpenAIQuotaAutoPauseSettings struct {
|
||||||
|
settings OpsOpenAIAccountQuotaAutoPauseSettings
|
||||||
|
expiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
const openAICodexUserAgentCacheTTL = 60 * time.Second
|
const openAICodexUserAgentCacheTTL = 60 * time.Second
|
||||||
const openAICodexUserAgentErrorTTL = 5 * time.Second
|
const openAICodexUserAgentErrorTTL = 5 * time.Second
|
||||||
const openAICodexUserAgentDBTimeout = 5 * time.Second
|
const openAICodexUserAgentDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// cachedOpenAIAllowCodexPlugin Codex 插件放行开关缓存(进程内缓存,60s TTL)。
|
||||||
|
// IsOpenAIAllowClaudeCodeCodexPluginEnabled 在每个 codex_cli_only 账号的网关请求热路径上被调用,避免每次访问 DB。
|
||||||
|
type cachedOpenAIAllowCodexPlugin struct {
|
||||||
|
value bool
|
||||||
|
expiresAt int64 // unix nano
|
||||||
|
}
|
||||||
|
|
||||||
|
const openAIAllowCodexPluginCacheTTL = 60 * time.Second
|
||||||
|
const openAIAllowCodexPluginErrorTTL = 5 * time.Second
|
||||||
|
const openAIAllowCodexPluginDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const openAIQuotaAutoPauseSettingsCacheTTL = 60 * time.Second
|
||||||
|
const openAIQuotaAutoPauseSettingsErrorTTL = 5 * time.Second
|
||||||
|
const openAIQuotaAutoPauseSettingsDBTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const openAIQuotaAutoPauseSettingsRefreshKey = "openai_quota_auto_pause_settings"
|
||||||
|
|
||||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||||
type DefaultSubscriptionGroupReader interface {
|
type DefaultSubscriptionGroupReader interface {
|
||||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||||
@ -152,17 +174,28 @@ type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[i
|
|||||||
|
|
||||||
// SettingService 系统设置服务
|
// SettingService 系统设置服务
|
||||||
type SettingService struct {
|
type SettingService struct {
|
||||||
settingRepo SettingRepository
|
settingRepo SettingRepository
|
||||||
defaultSubGroupReader DefaultSubscriptionGroupReader
|
defaultSubGroupReader DefaultSubscriptionGroupReader
|
||||||
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
|
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||||
version string // Application version
|
version string // Application version
|
||||||
webSearchManagerBuilder WebSearchManagerBuilder
|
webSearchManagerBuilder WebSearchManagerBuilder
|
||||||
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
antigravityUAVersionCache atomic.Value // *cachedAntigravityUserAgentVersion
|
||||||
antigravityUAVersionSF singleflight.Group
|
antigravityUAVersionSF singleflight.Group
|
||||||
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
|
openAICodexUACache atomic.Value // *cachedOpenAICodexUserAgent
|
||||||
openAICodexUASF singleflight.Group
|
openAICodexUASF singleflight.Group
|
||||||
|
openAIAllowCodexPluginCache atomic.Value // *cachedOpenAIAllowCodexPlugin
|
||||||
|
openAIAllowCodexPluginSF singleflight.Group
|
||||||
|
|
||||||
|
// openAIQuotaAutoPauseSettingsCache holds the most recently observed quota auto-pause
|
||||||
|
// settings. GetOpenAIQuotaAutoPauseSettings reads this atomic.Value on the request hot
|
||||||
|
// path without ever blocking on the DB; when the cached entry expires, a background
|
||||||
|
// goroutine refreshes it via openAIQuotaAutoPauseSettingsSF (stale-while-revalidate).
|
||||||
|
// This per-service field also gives tests natural isolation — each SettingService
|
||||||
|
// instance owns its own cache, no shared package-level state.
|
||||||
|
openAIQuotaAutoPauseSettingsCache atomic.Value // *cachedOpenAIQuotaAutoPauseSettings
|
||||||
|
openAIQuotaAutoPauseSettingsSF singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultPlatformQuotaSetting 单 platform 三档限额(nil = 沿用上层;0 = 显式禁用;>0 = 上限)
|
// DefaultPlatformQuotaSetting 单 platform 三档限额(nil = 沿用上层;0 = 显式禁用;>0 = 上限)
|
||||||
@ -1015,6 +1048,54 @@ func (s *SettingService) GetOpenAICodexUserAgent(ctx context.Context) string {
|
|||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOpenAIAllowClaudeCodeCodexPluginEnabled 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认关闭)。
|
||||||
|
// 仅在调用方已确认账号 codex_cli_only 开启时读取,避免对非受限账号产生无谓查询。
|
||||||
|
// 使用进程内 atomic.Value 缓存(60s TTL),避免在每个网关请求热路径上访问 DB。
|
||||||
|
func (s *SettingService) IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx context.Context) bool {
|
||||||
|
if cached, ok := s.openAIAllowCodexPluginCache.Load().(*cachedOpenAIAllowCodexPlugin); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result, _, _ := s.openAIAllowCodexPluginSF.Do("openai_allow_codex_plugin_enabled", func() (any, error) {
|
||||||
|
if cached, ok := s.openAIAllowCodexPluginCache.Load().(*cachedOpenAIAllowCodexPlugin); ok && cached != nil {
|
||||||
|
if time.Now().UnixNano() < cached.expiresAt {
|
||||||
|
return cached.value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAllowCodexPluginDBTimeout)
|
||||||
|
defer cancel()
|
||||||
|
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpenAIAllowClaudeCodeCodexPlugin)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrSettingNotFound) {
|
||||||
|
// 设置不存在 → 默认关闭,正常 TTL 缓存
|
||||||
|
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
|
||||||
|
value: false,
|
||||||
|
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
slog.Warn("failed to get openai_allow_claude_code_codex_plugin setting", "error", err)
|
||||||
|
// DB 错误 → 安全默认关闭,短 TTL 快速重试
|
||||||
|
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
|
||||||
|
value: false,
|
||||||
|
expiresAt: time.Now().Add(openAIAllowCodexPluginErrorTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
enabled := value == "true"
|
||||||
|
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
|
||||||
|
value: enabled,
|
||||||
|
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
return enabled, nil
|
||||||
|
})
|
||||||
|
if val, ok := result.(bool); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
||||||
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
||||||
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||||
@ -1816,6 +1897,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
|||||||
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
|
||||||
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
|
||||||
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
|
updates[SettingKeyOpenAICodexUserAgent] = strings.TrimSpace(settings.OpenAICodexUserAgent)
|
||||||
|
updates[SettingKeyOpenAIAllowClaudeCodeCodexPlugin] = strconv.FormatBool(settings.OpenAIAllowClaudeCodeCodexPlugin)
|
||||||
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
|
||||||
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
|
||||||
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
|
||||||
@ -1965,9 +2047,25 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
|||||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||||
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
||||||
})
|
})
|
||||||
|
// Invalidate the quota auto-pause cache and let the next read trigger a fresh load.
|
||||||
|
// We can't know from here whether ops_advanced_settings was also touched, so be
|
||||||
|
// defensive: store an expired entry — GetOpenAIQuotaAutoPauseSettings will serve
|
||||||
|
// stale and kick off an async refresh, never blocking the request that follows.
|
||||||
|
s.openAIQuotaAutoPauseSettingsSF.Forget(openAIQuotaAutoPauseSettingsRefreshKey)
|
||||||
|
if cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings); cached != nil {
|
||||||
|
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
|
||||||
|
settings: cached.settings,
|
||||||
|
expiresAt: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP)
|
s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP)
|
||||||
}
|
}
|
||||||
|
s.openAIAllowCodexPluginSF.Forget("openai_allow_codex_plugin_enabled")
|
||||||
|
s.openAIAllowCodexPluginCache.Store(&cachedOpenAIAllowCodexPlugin{
|
||||||
|
value: settings.OpenAIAllowClaudeCodeCodexPlugin,
|
||||||
|
expiresAt: time.Now().Add(openAIAllowCodexPluginCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
if s.onUpdate != nil {
|
if s.onUpdate != nil {
|
||||||
s.onUpdate() // Invalidate cache after settings update
|
s.onUpdate() // Invalidate cache after settings update
|
||||||
}
|
}
|
||||||
@ -3233,6 +3331,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
|||||||
}
|
}
|
||||||
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])
|
||||||
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
|
result.OpenAICodexUserAgent = strings.TrimSpace(settings[SettingKeyOpenAICodexUserAgent])
|
||||||
|
result.OpenAIAllowClaudeCodeCodexPlugin = settings[SettingKeyOpenAIAllowClaudeCodeCodexPlugin] == "true"
|
||||||
|
|
||||||
// Web search emulation: quick enabled check from the JSON config
|
// Web search emulation: quick enabled check from the JSON config
|
||||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||||
@ -4380,6 +4479,106 @@ func (s *SettingService) GetClaudeCodeVersionBounds(ctx context.Context) (min, m
|
|||||||
return b.min, b.max
|
return b.min, b.max
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIQuotaAutoPauseSettings returns the current global default quota auto-pause
|
||||||
|
// settings. It is invoked on the OpenAI scheduling hot path (once per request) and is
|
||||||
|
// therefore designed to never block on the DB:
|
||||||
|
//
|
||||||
|
// - Fresh cached value → returned immediately.
|
||||||
|
// - Stale or empty cache → the last known value is returned, and a background
|
||||||
|
// goroutine refreshes the cache via singleflight (stale-while-revalidate).
|
||||||
|
// - First call with no cache yet → zero defaults are returned and the same async
|
||||||
|
// refresh is kicked off; the next call gets the freshly populated value.
|
||||||
|
//
|
||||||
|
// Callers that need the freshly persisted value synchronously (tests, post-update
|
||||||
|
// confirmation, optional startup warm-up) should call WarmOpenAIQuotaAutoPauseSettings.
|
||||||
|
func (s *SettingService) GetOpenAIQuotaAutoPauseSettings(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
|
||||||
|
if s == nil {
|
||||||
|
return OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
}
|
||||||
|
cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings)
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
if cached != nil && now < cached.expiresAt {
|
||||||
|
return cached.settings
|
||||||
|
}
|
||||||
|
// Stale or unset: trigger background refresh without blocking this request.
|
||||||
|
// singleflight.DoChan dedupes concurrent refreshes; we deliberately ignore the
|
||||||
|
// returned channel — the result is observable via the atomic cache.
|
||||||
|
s.openAIQuotaAutoPauseSettingsSF.DoChan(openAIQuotaAutoPauseSettingsRefreshKey, func() (any, error) {
|
||||||
|
s.refreshOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
if cached != nil {
|
||||||
|
return cached.settings // serve stale value while revalidating
|
||||||
|
}
|
||||||
|
return OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WarmOpenAIQuotaAutoPauseSettings synchronously loads the quota auto-pause settings
|
||||||
|
// into the in-memory cache. Useful for application startup (so the first request hits
|
||||||
|
// a warm cache) and for tests that need deterministic reads immediately after
|
||||||
|
// constructing the service.
|
||||||
|
func (s *SettingService) WarmOpenAIQuotaAutoPauseSettings(ctx context.Context) OpsOpenAIAccountQuotaAutoPauseSettings {
|
||||||
|
if s == nil {
|
||||||
|
return OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
}
|
||||||
|
s.refreshOpenAIQuotaAutoPauseSettings(ctx)
|
||||||
|
cached, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings)
|
||||||
|
if cached == nil {
|
||||||
|
return OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
}
|
||||||
|
return cached.settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshOpenAIQuotaAutoPauseSettings reads the latest settings from the DB and stores
|
||||||
|
// them into the in-memory cache. On error it stores the prior value (or zero defaults
|
||||||
|
// if nothing is cached yet) with the shorter error TTL so the next refresh comes
|
||||||
|
// sooner. Always uses its own timeout-bounded context to keep refresh latency
|
||||||
|
// predictable regardless of the caller.
|
||||||
|
func (s *SettingService) refreshOpenAIQuotaAutoPauseSettings(ctx context.Context) {
|
||||||
|
if s == nil || s.settingRepo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIQuotaAutoPauseSettingsDBTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
settings := OpsOpenAIAccountQuotaAutoPauseSettings{}
|
||||||
|
ttl := openAIQuotaAutoPauseSettingsCacheTTL
|
||||||
|
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyOpsAdvancedSettings)
|
||||||
|
if err == nil {
|
||||||
|
cfg := defaultOpsAdvancedSettings()
|
||||||
|
if strings.TrimSpace(raw) != "" {
|
||||||
|
if jsonErr := json.Unmarshal([]byte(raw), cfg); jsonErr == nil {
|
||||||
|
normalizeOpsAdvancedSettings(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
settings = cfg.OpenAIAccountQuotaAutoPause
|
||||||
|
} else if !errors.Is(err, ErrSettingNotFound) {
|
||||||
|
// Real error: keep serving prior value but refresh sooner.
|
||||||
|
if prior, _ := s.openAIQuotaAutoPauseSettingsCache.Load().(*cachedOpenAIQuotaAutoPauseSettings); prior != nil {
|
||||||
|
settings = prior.settings
|
||||||
|
}
|
||||||
|
ttl = openAIQuotaAutoPauseSettingsErrorTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
|
||||||
|
settings: settings,
|
||||||
|
expiresAt: time.Now().Add(ttl).UnixNano(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOpenAIQuotaAutoPauseSettings writes the given settings directly into the in-memory
|
||||||
|
// cache. Called from settings-write code paths so that the next read reflects the new
|
||||||
|
// value immediately, without waiting for the background refresh.
|
||||||
|
func (s *SettingService) SetOpenAIQuotaAutoPauseSettings(settings OpsOpenAIAccountQuotaAutoPauseSettings) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.openAIQuotaAutoPauseSettingsCache.Store(&cachedOpenAIQuotaAutoPauseSettings{
|
||||||
|
settings: settings,
|
||||||
|
expiresAt: time.Now().Add(openAIQuotaAutoPauseSettingsCacheTTL).UnixNano(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// GetRectifierSettings 获取请求整流器配置
|
// GetRectifierSettings 获取请求整流器配置
|
||||||
func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) {
|
func (s *SettingService) GetRectifierSettings(ctx context.Context) (*RectifierSettings, error) {
|
||||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings)
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyRectifierSettings)
|
||||||
|
|||||||
@ -0,0 +1,55 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type allowClaudeCodeSettingRepoStub struct{ values map[string]string }
|
||||||
|
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
if v, ok := s.values[key]; ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
return "", ErrSettingNotFound
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
func (s *allowClaudeCodeSettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unused")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingService_IsOpenAIAllowClaudeCodeCodexPluginEnabled(t *testing.T) {
|
||||||
|
t.Run("默认关闭(设置缺失)", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{}}, &config.Config{})
|
||||||
|
require.False(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
|
||||||
|
})
|
||||||
|
t.Run("值为 true 时开启", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyOpenAIAllowClaudeCodeCodexPlugin: "true",
|
||||||
|
}}, &config.Config{})
|
||||||
|
require.True(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
|
||||||
|
})
|
||||||
|
t.Run("值非 true 时关闭", func(t *testing.T) {
|
||||||
|
svc := NewSettingService(&allowClaudeCodeSettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyOpenAIAllowClaudeCodeCodexPlugin: "false",
|
||||||
|
}}, &config.Config{})
|
||||||
|
require.False(t, svc.IsOpenAIAllowClaudeCodeCodexPluginEnabled(context.Background()))
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -195,6 +195,7 @@ type SystemSettings struct {
|
|||||||
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
RewriteMessageCacheControl bool // 是否改写 messages[*].content[*].cache_control(默认 false)
|
||||||
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
AntigravityUserAgentVersion string // Antigravity 上游 User-Agent 版本号;空值使用配置/默认值
|
||||||
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent;空值使用内置默认
|
OpenAICodexUserAgent string // OpenAI Codex 上游完整 User-Agent;空值使用内置默认
|
||||||
|
OpenAIAllowClaudeCodeCodexPlugin bool // 全局开关:是否额外放行 Claude Code 的 Codex 插件(默认 false)
|
||||||
|
|
||||||
// Web Search Emulation
|
// Web Search Emulation
|
||||||
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
|
||||||
|
|||||||
@ -17,6 +17,12 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoUpdateAvailable = infraerrors.Conflict("ALREADY_UP_TO_DATE", "no update available; current version is latest")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -146,7 +152,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !info.HasUpdate {
|
if !info.HasUpdate {
|
||||||
return fmt.Errorf("no update available")
|
return ErrNoUpdateAvailable
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find matching archive and checksum for current platform
|
// Find matching archive and checksum for current platform
|
||||||
|
|||||||
64
backend/internal/service/update_service_test.go
Normal file
64
backend/internal/service/update_service_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type updateServiceCacheStub struct {
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *updateServiceCacheStub) GetUpdateInfo(context.Context) (string, error) {
|
||||||
|
if s.data == "" {
|
||||||
|
return "", errors.New("cache miss")
|
||||||
|
}
|
||||||
|
return s.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *updateServiceCacheStub) SetUpdateInfo(_ context.Context, data string, _ time.Duration) error {
|
||||||
|
s.data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateServiceGitHubClientStub struct {
|
||||||
|
release *GitHubRelease
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *updateServiceGitHubClientStub) FetchLatestRelease(context.Context, string) (*GitHubRelease, error) {
|
||||||
|
return s.release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *updateServiceGitHubClientStub) DownloadFile(context.Context, string, string, int64) error {
|
||||||
|
panic("DownloadFile should not be called when no update is available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *updateServiceGitHubClientStub) FetchChecksumFile(context.Context, string) ([]byte, error) {
|
||||||
|
panic("FetchChecksumFile should not be called when no update is available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateServicePerformUpdateNoUpdateReturnsSentinel(t *testing.T) {
|
||||||
|
svc := NewUpdateService(
|
||||||
|
&updateServiceCacheStub{},
|
||||||
|
&updateServiceGitHubClientStub{
|
||||||
|
release: &GitHubRelease{
|
||||||
|
TagName: "v0.1.132",
|
||||||
|
Name: "v0.1.132",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"0.1.132",
|
||||||
|
"release",
|
||||||
|
)
|
||||||
|
|
||||||
|
err := svc.PerformUpdate(context.Background())
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, ErrNoUpdateAvailable))
|
||||||
|
require.ErrorIs(t, err, ErrNoUpdateAvailable)
|
||||||
|
}
|
||||||
@ -399,6 +399,46 @@ func ProvideBackupService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideOpsService constructs OpsService and wires the SettingService-backed quota
|
||||||
|
// auto-pause cache sink. Mirrors the SetCleanupReloader pattern: OpsService doesn't
|
||||||
|
// hold a *SettingService reference, but wire injects a tiny callback so writes to
|
||||||
|
// ops_advanced_settings immediately propagate into the scheduler hot-path cache.
|
||||||
|
func ProvideOpsService(
|
||||||
|
opsRepo OpsRepository,
|
||||||
|
settingRepo SettingRepository,
|
||||||
|
cfg *config.Config,
|
||||||
|
accountRepo AccountRepository,
|
||||||
|
userRepo UserRepository,
|
||||||
|
concurrencyService *ConcurrencyService,
|
||||||
|
gatewayService *GatewayService,
|
||||||
|
openAIGatewayService *OpenAIGatewayService,
|
||||||
|
geminiCompatService *GeminiMessagesCompatService,
|
||||||
|
antigravityGatewayService *AntigravityGatewayService,
|
||||||
|
systemLogSink *OpsSystemLogSink,
|
||||||
|
settingService *SettingService,
|
||||||
|
) *OpsService {
|
||||||
|
svc := NewOpsService(
|
||||||
|
opsRepo,
|
||||||
|
settingRepo,
|
||||||
|
cfg,
|
||||||
|
accountRepo,
|
||||||
|
userRepo,
|
||||||
|
concurrencyService,
|
||||||
|
gatewayService,
|
||||||
|
openAIGatewayService,
|
||||||
|
geminiCompatService,
|
||||||
|
antigravityGatewayService,
|
||||||
|
systemLogSink,
|
||||||
|
)
|
||||||
|
if settingService != nil {
|
||||||
|
svc.SetOpenAIQuotaAutoPauseSettingsSink(settingService.SetOpenAIQuotaAutoPauseSettings)
|
||||||
|
// Optional warm-up so the first scheduled request after process start observes
|
||||||
|
// a populated cache rather than zero defaults. Best-effort, sync-bounded.
|
||||||
|
settingService.WarmOpenAIQuotaAutoPauseSettings(context.Background())
|
||||||
|
}
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideSettingService wires SettingService with group reader and proxy repo.
|
// ProvideSettingService wires SettingService with group reader and proxy repo.
|
||||||
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
|
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
|
||||||
svc := NewSettingService(settingRepo, cfg)
|
svc := NewSettingService(settingRepo, cfg)
|
||||||
@ -486,7 +526,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideBackupService,
|
ProvideBackupService,
|
||||||
NewHealthService,
|
NewHealthService,
|
||||||
ProvideOpsSystemLogSink,
|
ProvideOpsSystemLogSink,
|
||||||
NewOpsService,
|
ProvideOpsService,
|
||||||
ProvideOpsMetricsCollector,
|
ProvideOpsMetricsCollector,
|
||||||
ProvideOpsAggregationService,
|
ProvideOpsAggregationService,
|
||||||
ProvideOpsAlertEvaluatorService,
|
ProvideOpsAlertEvaluatorService,
|
||||||
|
|||||||
16
backend/migrations/144_add_opus48_to_model_mapping.sql
Normal file
16
backend/migrations/144_add_opus48_to_model_mapping.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
-- 为已持久化的 Antigravity model_mapping 添加 claude-opus-4-8。
|
||||||
|
--
|
||||||
|
-- 未持久化 model_mapping 的账号会直接使用 DefaultAntigravityModelMapping,
|
||||||
|
-- 因此这里只需要回填已有映射对象。
|
||||||
|
|
||||||
|
UPDATE accounts
|
||||||
|
SET credentials = jsonb_set(
|
||||||
|
credentials,
|
||||||
|
'{model_mapping,claude-opus-4-8}',
|
||||||
|
'"claude-opus-4-8"'::jsonb,
|
||||||
|
true
|
||||||
|
)
|
||||||
|
WHERE platform = 'antigravity'
|
||||||
|
AND deleted_at IS NULL
|
||||||
|
AND jsonb_typeof(credentials->'model_mapping') = 'object'
|
||||||
|
AND credentials->'model_mapping'->>'claude-opus-4-8' IS NULL;
|
||||||
File diff suppressed because it is too large
Load Diff
@ -778,9 +778,15 @@ export interface OpsAlertRuntimeSettings {
|
|||||||
thresholds: OpsMetricThresholds // 指标阈值配置
|
thresholds: OpsMetricThresholds // 指标阈值配置
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface OpsOpenAIAccountQuotaAutoPauseSettings {
|
||||||
|
default_threshold_5h: number // 0~1,0 表示不启用全局默认 5h 阈值
|
||||||
|
default_threshold_7d: number // 0~1,0 表示不启用全局默认 7d 阈值
|
||||||
|
}
|
||||||
|
|
||||||
export interface OpsAdvancedSettings {
|
export interface OpsAdvancedSettings {
|
||||||
data_retention: OpsDataRetentionSettings
|
data_retention: OpsDataRetentionSettings
|
||||||
aggregation: OpsAggregationSettings
|
aggregation: OpsAggregationSettings
|
||||||
|
openai_account_quota_auto_pause: OpsOpenAIAccountQuotaAutoPauseSettings
|
||||||
ignore_count_tokens_errors: boolean
|
ignore_count_tokens_errors: boolean
|
||||||
ignore_context_canceled: boolean
|
ignore_context_canceled: boolean
|
||||||
ignore_no_available_accounts: boolean
|
ignore_no_available_accounts: boolean
|
||||||
|
|||||||
@ -561,6 +561,7 @@ export interface SystemSettings {
|
|||||||
rewrite_message_cache_control: boolean;
|
rewrite_message_cache_control: boolean;
|
||||||
antigravity_user_agent_version: string;
|
antigravity_user_agent_version: string;
|
||||||
openai_codex_user_agent: string;
|
openai_codex_user_agent: string;
|
||||||
|
openai_allow_claude_code_codex_plugin: boolean;
|
||||||
web_search_emulation_enabled?: boolean;
|
web_search_emulation_enabled?: boolean;
|
||||||
|
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
@ -794,6 +795,7 @@ export interface UpdateSettingsRequest {
|
|||||||
rewrite_message_cache_control?: boolean;
|
rewrite_message_cache_control?: boolean;
|
||||||
antigravity_user_agent_version?: string;
|
antigravity_user_agent_version?: string;
|
||||||
openai_codex_user_agent?: string;
|
openai_codex_user_agent?: string;
|
||||||
|
openai_allow_claude_code_codex_plugin?: boolean;
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
payment_enabled?: boolean;
|
payment_enabled?: boolean;
|
||||||
risk_control_enabled?: boolean;
|
risk_control_enabled?: boolean;
|
||||||
|
|||||||
@ -222,6 +222,8 @@ const formatScopeName = (scope: string): string => {
|
|||||||
// Claude 系列
|
// Claude 系列
|
||||||
'claude-opus-4-6': 'COpus46',
|
'claude-opus-4-6': 'COpus46',
|
||||||
'claude-opus-4-6-thinking': 'COpus46T',
|
'claude-opus-4-6-thinking': 'COpus46T',
|
||||||
|
'claude-opus-4-7': 'COpus47',
|
||||||
|
'claude-opus-4-8': 'COpus48',
|
||||||
'claude-sonnet-4-6': 'CSon46',
|
'claude-sonnet-4-6': 'CSon46',
|
||||||
'claude-sonnet-4-5': 'CSon45',
|
'claude-sonnet-4-5': 'CSon45',
|
||||||
'claude-sonnet-4-5-thinking': 'CSon45T',
|
'claude-sonnet-4-5-thinking': 'CSon45T',
|
||||||
|
|||||||
@ -697,6 +697,7 @@ const antigravityClaudeUsageFromAPI = computed(() =>
|
|||||||
getAntigravityUsageFromAPI([
|
getAntigravityUsageFromAPI([
|
||||||
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
|
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
|
||||||
'claude-sonnet-4-6', 'claude-opus-4-6', 'claude-opus-4-6-thinking',
|
'claude-sonnet-4-6', 'claude-opus-4-6', 'claude-opus-4-6-thinking',
|
||||||
|
'claude-opus-4-7', 'claude-opus-4-8',
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -742,6 +742,50 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- OpenAI OAuth: 额外放行 Claude Code 的 Codex 插件 -->
|
||||||
|
<div v-if="allOpenAIOAuth" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
|
<div class="mb-3 flex items-center justify-between">
|
||||||
|
<label
|
||||||
|
id="bulk-edit-openai-codex-allow-claude-code-label"
|
||||||
|
class="input-label mb-0"
|
||||||
|
for="bulk-edit-openai-codex-allow-claude-code-enabled"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
v-model="enableCodexCLIOnlyAllowClaudeCode"
|
||||||
|
id="bulk-edit-openai-codex-allow-claude-code-enabled"
|
||||||
|
type="checkbox"
|
||||||
|
aria-controls="bulk-edit-openai-codex-allow-claude-code"
|
||||||
|
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
id="bulk-edit-openai-codex-allow-claude-code"
|
||||||
|
:class="!enableCodexCLIOnlyAllowClaudeCode && 'pointer-events-none opacity-50'"
|
||||||
|
>
|
||||||
|
<p class="mb-3 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
|
||||||
|
</p>
|
||||||
|
<button
|
||||||
|
id="bulk-edit-openai-codex-allow-claude-code-toggle"
|
||||||
|
type="button"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- OpenAI API Key WS mode -->
|
<!-- OpenAI API Key WS mode -->
|
||||||
<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
<div v-if="allOpenAIAPIKey" class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||||
<div class="mb-3 flex items-center justify-between">
|
<div class="mb-3 flex items-center justify-between">
|
||||||
@ -1219,6 +1263,7 @@ const enableOpenAIPassthrough = ref(false)
|
|||||||
const enableOpenAIWSMode = ref(false)
|
const enableOpenAIWSMode = ref(false)
|
||||||
const enableOpenAIAPIKeyWSMode = ref(false)
|
const enableOpenAIAPIKeyWSMode = ref(false)
|
||||||
const enableCodexCLIOnly = ref(false)
|
const enableCodexCLIOnly = ref(false)
|
||||||
|
const enableCodexCLIOnlyAllowClaudeCode = ref(false)
|
||||||
const enableOpenAICompactMode = ref(false)
|
const enableOpenAICompactMode = ref(false)
|
||||||
const enableOpenAICompactModelMapping = ref(false)
|
const enableOpenAICompactModelMapping = ref(false)
|
||||||
const enableRpmLimit = ref(false)
|
const enableRpmLimit = ref(false)
|
||||||
@ -1246,6 +1291,7 @@ const openaiPassthroughEnabled = ref(false)
|
|||||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
|
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
|
||||||
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
||||||
const openAICompactModelMappings = ref<ModelMapping[]>([])
|
const openAICompactModelMappings = ref<ModelMapping[]>([])
|
||||||
const rpmLimitEnabled = ref(false)
|
const rpmLimitEnabled = ref(false)
|
||||||
@ -1496,6 +1542,11 @@ const buildUpdatePayload = (): Record<string, unknown> | null => {
|
|||||||
extra.codex_cli_only = codexCLIOnlyEnabled.value
|
extra.codex_cli_only = codexCLIOnlyEnabled.value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enableCodexCLIOnlyAllowClaudeCode.value) {
|
||||||
|
const extra = ensureExtra()
|
||||||
|
extra.codex_cli_only_allowed_clients = codexCLIOnlyAllowClaudeCodeEnabled.value ? ['claude_code'] : []
|
||||||
|
}
|
||||||
|
|
||||||
if (enableOpenAICompactMode.value) {
|
if (enableOpenAICompactMode.value) {
|
||||||
const extra = ensureExtra()
|
const extra = ensureExtra()
|
||||||
extra.openai_compact_mode = openAICompactMode.value
|
extra.openai_compact_mode = openAICompactMode.value
|
||||||
@ -1602,6 +1653,7 @@ const handleSubmit = async () => {
|
|||||||
enableOpenAIWSMode.value ||
|
enableOpenAIWSMode.value ||
|
||||||
enableOpenAIAPIKeyWSMode.value ||
|
enableOpenAIAPIKeyWSMode.value ||
|
||||||
enableCodexCLIOnly.value ||
|
enableCodexCLIOnly.value ||
|
||||||
|
enableCodexCLIOnlyAllowClaudeCode.value ||
|
||||||
enableOpenAICompactMode.value ||
|
enableOpenAICompactMode.value ||
|
||||||
enableOpenAICompactModelMapping.value ||
|
enableOpenAICompactModelMapping.value ||
|
||||||
enableRpmLimit.value ||
|
enableRpmLimit.value ||
|
||||||
@ -1704,6 +1756,7 @@ watch(
|
|||||||
enableOpenAIWSMode.value = false
|
enableOpenAIWSMode.value = false
|
||||||
enableOpenAIAPIKeyWSMode.value = false
|
enableOpenAIAPIKeyWSMode.value = false
|
||||||
enableCodexCLIOnly.value = false
|
enableCodexCLIOnly.value = false
|
||||||
|
enableCodexCLIOnlyAllowClaudeCode.value = false
|
||||||
enableOpenAICompactMode.value = false
|
enableOpenAICompactMode.value = false
|
||||||
enableOpenAICompactModelMapping.value = false
|
enableOpenAICompactModelMapping.value = false
|
||||||
enableRpmLimit.value = false
|
enableRpmLimit.value = false
|
||||||
@ -1727,6 +1780,7 @@ watch(
|
|||||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value = false
|
||||||
openAICompactMode.value = 'auto'
|
openAICompactMode.value = 'auto'
|
||||||
openAICompactModelMappings.value = []
|
openAICompactModelMappings.value = []
|
||||||
rpmLimitEnabled.value = false
|
rpmLimitEnabled.value = false
|
||||||
|
|||||||
@ -2746,6 +2746,32 @@
|
|||||||
/>
|
/>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="codexCLIOnlyEnabled"
|
||||||
|
class="mt-4 flex items-center justify-between border-l-2 border-gray-200 pl-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- OpenAI Compact 能力配置 -->
|
<!-- OpenAI Compact 能力配置 -->
|
||||||
@ -2790,7 +2816,7 @@
|
|||||||
<!-- OpenAI APIKey Responses API support mode -->
|
<!-- OpenAI APIKey Responses API support mode -->
|
||||||
<div
|
<div
|
||||||
v-if="form.platform === 'openai' && accountCategory === 'apikey'"
|
v-if="form.platform === 'openai' && accountCategory === 'apikey'"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600"
|
class="space-y-4 border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-center justify-between gap-4">
|
<div class="flex items-center justify-between gap-4">
|
||||||
<div>
|
<div>
|
||||||
@ -2803,10 +2829,38 @@
|
|||||||
<Select
|
<Select
|
||||||
v-model="openAIResponsesMode"
|
v-model="openAIResponsesMode"
|
||||||
:options="openAIResponsesModeOptions"
|
:options="openAIResponsesModeOptions"
|
||||||
|
:disabled="!openAITextGenerationCapabilityEnabled"
|
||||||
data-testid="openai-responses-mode-select"
|
data-testid="openai-responses-mode-select"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<p
|
||||||
|
v-if="!openAITextGenerationCapabilityEnabled"
|
||||||
|
class="rounded-lg bg-amber-50 px-3 py-2 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300"
|
||||||
|
data-testid="openai-responses-mode-not-applicable"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.openai.responsesModeTextDisabledHint') }}
|
||||||
|
</p>
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-2 block">{{ t('admin.accounts.openai.endpointCapabilities') }}</label>
|
||||||
|
<div class="grid grid-cols-1 gap-2 sm:grid-cols-2">
|
||||||
|
<label
|
||||||
|
v-for="option in openAIEndpointCapabilityOptions"
|
||||||
|
:key="option.value"
|
||||||
|
class="flex cursor-pointer items-center gap-2 rounded-lg border border-gray-200 px-3 py-2 text-sm dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||||
|
:data-testid="`openai-endpoint-capability-${option.value}`"
|
||||||
|
:checked="openAIEndpointCapabilities.includes(option.value)"
|
||||||
|
@change="toggleOpenAIEndpointCapability(option.value, $event)"
|
||||||
|
/>
|
||||||
|
<span class="text-gray-700 dark:text-gray-200">{{ option.label }}</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
@ -3287,7 +3341,8 @@ import type {
|
|||||||
CreateAccountRequest,
|
CreateAccountRequest,
|
||||||
CodexSessionImportMessage,
|
CodexSessionImportMessage,
|
||||||
OpenAICompactMode,
|
OpenAICompactMode,
|
||||||
OpenAIResponsesMode
|
OpenAIResponsesMode,
|
||||||
|
OpenAIEndpointCapability
|
||||||
} from '@/types'
|
} from '@/types'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||||
@ -3466,9 +3521,11 @@ const autoPauseOnExpired = ref(true)
|
|||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
||||||
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
|
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
|
||||||
|
const openAIEndpointCapabilities = ref<OpenAIEndpointCapability[]>(['chat_completions', 'embeddings'])
|
||||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
|
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
|
||||||
const anthropicPassthroughEnabled = ref(false)
|
const anthropicPassthroughEnabled = ref(false)
|
||||||
const webSearchEmulationMode = ref('default')
|
const webSearchEmulationMode = ref('default')
|
||||||
const webSearchGlobalEnabled = ref(false)
|
const webSearchGlobalEnabled = ref(false)
|
||||||
@ -3534,6 +3591,58 @@ const openAIResponsesModeOptions = computed(() => [
|
|||||||
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
|
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
|
||||||
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
|
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
|
||||||
])
|
])
|
||||||
|
const openAITextEndpointCapabilityLabel = computed(() => {
|
||||||
|
if (openAIResponsesMode.value === 'force_responses') {
|
||||||
|
return t('admin.accounts.openai.capabilityResponses')
|
||||||
|
}
|
||||||
|
if (openAIResponsesMode.value === 'force_chat_completions') {
|
||||||
|
return t('admin.accounts.openai.capabilityChatCompletions')
|
||||||
|
}
|
||||||
|
return t('admin.accounts.openai.capabilityTextAuto')
|
||||||
|
})
|
||||||
|
const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [
|
||||||
|
{ value: 'chat_completions', label: openAITextEndpointCapabilityLabel.value },
|
||||||
|
{ value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') }
|
||||||
|
])
|
||||||
|
const openAITextGenerationCapabilityEnabled = computed(() =>
|
||||||
|
openAIEndpointCapabilities.value.includes('chat_completions')
|
||||||
|
)
|
||||||
|
|
||||||
|
const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => {
|
||||||
|
const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings']
|
||||||
|
const selected = allowed.filter((value) => values.includes(value))
|
||||||
|
return selected.length > 0 ? selected : allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => {
|
||||||
|
if (openAIEndpointCapabilities.value.includes(capability)) {
|
||||||
|
if (openAIEndpointCapabilities.value.length <= 1) {
|
||||||
|
const input = event?.target as HTMLInputElement | null
|
||||||
|
if (input) input.checked = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter(
|
||||||
|
(value) => value !== capability
|
||||||
|
)
|
||||||
|
if (!openAITextGenerationCapabilityEnabled.value) {
|
||||||
|
openAIResponsesMode.value = 'auto'
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([
|
||||||
|
...openAIEndpointCapabilities.value,
|
||||||
|
capability
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
const applyOpenAIEndpointCapabilities = (credentials: Record<string, unknown>) => {
|
||||||
|
const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value)
|
||||||
|
if (capabilities.length === 2) {
|
||||||
|
delete credentials.openai_capabilities
|
||||||
|
return
|
||||||
|
}
|
||||||
|
credentials.openai_capabilities = capabilities
|
||||||
|
}
|
||||||
|
|
||||||
function buildAntigravityExtra(): Record<string, unknown> | undefined {
|
function buildAntigravityExtra(): Record<string, unknown> | undefined {
|
||||||
const extra: Record<string, unknown> = {}
|
const extra: Record<string, unknown> = {}
|
||||||
@ -3847,9 +3956,11 @@ watch(
|
|||||||
}
|
}
|
||||||
if (newPlatform !== 'openai') {
|
if (newPlatform !== 'openai') {
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
|
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
|
||||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value = false
|
||||||
}
|
}
|
||||||
if (newPlatform !== 'anthropic') {
|
if (newPlatform !== 'anthropic') {
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
@ -3870,6 +3981,7 @@ watch(
|
|||||||
([category, platform]) => {
|
([category, platform]) => {
|
||||||
if (platform === 'openai' && category !== 'oauth-based') {
|
if (platform === 'openai' && category !== 'oauth-based') {
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value = false
|
||||||
}
|
}
|
||||||
if (platform !== 'anthropic' || category !== 'apikey') {
|
if (platform !== 'anthropic' || category !== 'apikey') {
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
@ -4268,9 +4380,11 @@ const resetForm = () => {
|
|||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
openAICompactMode.value = 'auto'
|
openAICompactMode.value = 'auto'
|
||||||
openAIResponsesMode.value = 'auto'
|
openAIResponsesMode.value = 'auto'
|
||||||
|
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
|
||||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value = false
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
webSearchEmulationMode.value = 'default'
|
webSearchEmulationMode.value = 'default'
|
||||||
// Reset quota control state
|
// Reset quota control state
|
||||||
@ -4353,13 +4467,26 @@ const buildOpenAIExtra = (base?: Record<string, unknown>): Record<string, unknow
|
|||||||
} else {
|
} else {
|
||||||
delete extra.codex_cli_only
|
delete extra.codex_cli_only
|
||||||
}
|
}
|
||||||
|
if (
|
||||||
|
accountCategory.value === 'oauth-based' &&
|
||||||
|
codexCLIOnlyEnabled.value &&
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value
|
||||||
|
) {
|
||||||
|
extra.codex_cli_only_allowed_clients = ['claude_code']
|
||||||
|
} else {
|
||||||
|
delete extra.codex_cli_only_allowed_clients
|
||||||
|
}
|
||||||
if (openAICompactMode.value !== 'auto') {
|
if (openAICompactMode.value !== 'auto') {
|
||||||
extra.openai_compact_mode = openAICompactMode.value
|
extra.openai_compact_mode = openAICompactMode.value
|
||||||
} else {
|
} else {
|
||||||
delete extra.openai_compact_mode
|
delete extra.openai_compact_mode
|
||||||
}
|
}
|
||||||
|
|
||||||
if (accountCategory.value === 'apikey' && openAIResponsesMode.value !== 'auto') {
|
if (
|
||||||
|
accountCategory.value === 'apikey' &&
|
||||||
|
openAITextGenerationCapabilityEnabled.value &&
|
||||||
|
openAIResponsesMode.value !== 'auto'
|
||||||
|
) {
|
||||||
extra.openai_responses_mode = openAIResponsesMode.value
|
extra.openai_responses_mode = openAIResponsesMode.value
|
||||||
} else {
|
} else {
|
||||||
delete extra.openai_responses_mode
|
delete extra.openai_responses_mode
|
||||||
@ -4689,6 +4816,7 @@ const handleSubmit = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (form.platform === 'openai') {
|
if (form.platform === 'openai') {
|
||||||
|
applyOpenAIEndpointCapabilities(credentials)
|
||||||
const compactModelMapping = buildOpenAICompactModelMapping()
|
const compactModelMapping = buildOpenAICompactModelMapping()
|
||||||
if (compactModelMapping) {
|
if (compactModelMapping) {
|
||||||
credentials.compact_model_mapping = compactModelMapping
|
credentials.compact_model_mapping = compactModelMapping
|
||||||
@ -4811,6 +4939,9 @@ const createAccountAndFinish = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (platform === 'openai') {
|
if (platform === 'openai') {
|
||||||
|
if (type === 'apikey') {
|
||||||
|
applyOpenAIEndpointCapabilities(credentials)
|
||||||
|
}
|
||||||
const compactModelMapping = buildOpenAICompactModelMapping()
|
const compactModelMapping = buildOpenAICompactModelMapping()
|
||||||
if (compactModelMapping) {
|
if (compactModelMapping) {
|
||||||
credentials.compact_model_mapping = compactModelMapping
|
credentials.compact_model_mapping = compactModelMapping
|
||||||
|
|||||||
@ -1439,7 +1439,7 @@
|
|||||||
<!-- OpenAI APIKey Responses API support mode -->
|
<!-- OpenAI APIKey Responses API support mode -->
|
||||||
<div
|
<div
|
||||||
v-if="account?.platform === 'openai' && account?.type === 'apikey'"
|
v-if="account?.platform === 'openai' && account?.type === 'apikey'"
|
||||||
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-3"
|
class="space-y-4 border-t border-gray-200 pt-4 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-center justify-between gap-4">
|
<div class="flex items-center justify-between gap-4">
|
||||||
<div>
|
<div>
|
||||||
@ -1452,13 +1452,44 @@
|
|||||||
<Select
|
<Select
|
||||||
v-model="openAIResponsesMode"
|
v-model="openAIResponsesMode"
|
||||||
:options="openAIResponsesModeOptions"
|
:options="openAIResponsesModeOptions"
|
||||||
|
:disabled="!openAITextGenerationCapabilityEnabled"
|
||||||
data-testid="openai-responses-mode-select"
|
data-testid="openai-responses-mode-select"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="rounded-lg bg-gray-50 px-3 py-2 text-xs text-gray-600 dark:bg-dark-700 dark:text-gray-300">
|
<div
|
||||||
|
v-if="openAITextGenerationCapabilityEnabled"
|
||||||
|
class="rounded-lg bg-gray-50 px-3 py-2 text-xs text-gray-600 dark:bg-dark-700 dark:text-gray-300"
|
||||||
|
>
|
||||||
<span class="font-medium">{{ t(openAIResponsesStatusKey) }}</span>
|
<span class="font-medium">{{ t(openAIResponsesStatusKey) }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-else
|
||||||
|
class="rounded-lg bg-amber-50 px-3 py-2 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300"
|
||||||
|
data-testid="openai-responses-mode-not-applicable"
|
||||||
|
>
|
||||||
|
{{ t('admin.accounts.openai.responsesModeTextDisabledHint') }}
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-2 block">{{ t('admin.accounts.openai.endpointCapabilities') }}</label>
|
||||||
|
<div class="grid grid-cols-1 gap-2 sm:grid-cols-2">
|
||||||
|
<label
|
||||||
|
v-for="option in openAIEndpointCapabilityOptions"
|
||||||
|
:key="option.value"
|
||||||
|
class="flex cursor-pointer items-center gap-2 rounded-lg border border-gray-200 px-3 py-2 text-sm dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||||
|
:data-testid="`openai-endpoint-capability-${option.value}`"
|
||||||
|
:checked="openAIEndpointCapabilities.includes(option.value)"
|
||||||
|
@change="toggleOpenAIEndpointCapability(option.value, $event)"
|
||||||
|
/>
|
||||||
|
<span class="text-gray-700 dark:text-gray-200">{{ option.label }}</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Anthropic API Key 自动透传开关 -->
|
<!-- Anthropic API Key 自动透传开关 -->
|
||||||
@ -1642,6 +1673,32 @@
|
|||||||
/>
|
/>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="codexCLIOnlyEnabled"
|
||||||
|
class="mt-4 flex items-center justify-between border-l-2 border-gray-200 pl-4 dark:border-dark-600"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCode') }}</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t('admin.accounts.openai.codexCLIOnlyAllowClaudeCodeDesc') }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="codexCLIOnlyAllowClaudeCodeEnabled = !codexCLIOnlyAllowClaudeCodeEnabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
@ -1730,6 +1787,84 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="account?.platform === 'openai'"
|
||||||
|
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
|
||||||
|
>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.autoPause5hDisabled') }}</label>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="autoPause5hDisabled = !autoPause5hDisabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
autoPause5hDisabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
data-testid="auto-pause-5h-disabled"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
autoPause5hDisabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.autoPauseDisabledHint') }}</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.accounts.autoPause5hThreshold') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="autoPause5hThreshold"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="100"
|
||||||
|
step="0.1"
|
||||||
|
class="input"
|
||||||
|
:disabled="autoPause5hDisabled"
|
||||||
|
data-testid="auto-pause-5h-threshold"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.autoPauseThresholdHint') }}</p>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<label class="input-label mb-0">{{ t('admin.accounts.autoPause7dDisabled') }}</label>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="autoPause7dDisabled = !autoPause7dDisabled"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||||
|
autoPause7dDisabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
data-testid="auto-pause-7d-disabled"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
|
autoPause7dDisabled ? 'translate-x-5' : 'translate-x-0'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.autoPauseDisabledHint') }}</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.accounts.autoPause7dThreshold') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="autoPause7dThreshold"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="100"
|
||||||
|
step="0.1"
|
||||||
|
class="input"
|
||||||
|
:disabled="autoPause7dDisabled"
|
||||||
|
data-testid="auto-pause-7d-threshold"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.accounts.autoPauseThresholdHint') }}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 配额控制 (Anthropic OAuth/SetupToken: 亲和 + 窗口费用 + 会话 + RPM 等) -->
|
<!-- 配额控制 (Anthropic OAuth/SetupToken: 亲和 + 窗口费用 + 会话 + RPM 等) -->
|
||||||
<div
|
<div
|
||||||
v-if="account?.platform === 'anthropic' && (account?.type === 'oauth' || account?.type === 'setup-token')"
|
v-if="account?.platform === 'anthropic' && (account?.type === 'oauth' || account?.type === 'setup-token')"
|
||||||
@ -2245,7 +2380,15 @@ import { useAppStore } from '@/stores/app'
|
|||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
|
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
|
||||||
import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode, OpenAIResponsesMode } from '@/types'
|
import type {
|
||||||
|
Account,
|
||||||
|
Proxy,
|
||||||
|
AdminGroup,
|
||||||
|
CheckMixedChannelResponse,
|
||||||
|
OpenAICompactMode,
|
||||||
|
OpenAIResponsesMode,
|
||||||
|
OpenAIEndpointCapability
|
||||||
|
} from '@/types'
|
||||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
@ -2382,6 +2525,10 @@ const selectedErrorCodes = ref<number[]>([])
|
|||||||
const customErrorCodeInput = ref<number | null>(null)
|
const customErrorCodeInput = ref<number | null>(null)
|
||||||
const interceptWarmupRequests = ref(false)
|
const interceptWarmupRequests = ref(false)
|
||||||
const autoPauseOnExpired = ref(false)
|
const autoPauseOnExpired = ref(false)
|
||||||
|
const autoPause5hThreshold = ref<number | null>(null)
|
||||||
|
const autoPause7dThreshold = ref<number | null>(null)
|
||||||
|
const autoPause5hDisabled = ref(false)
|
||||||
|
const autoPause7dDisabled = ref(false)
|
||||||
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
|
||||||
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
|
||||||
const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||||
@ -2433,9 +2580,11 @@ const customBaseUrl = ref('')
|
|||||||
const openaiPassthroughEnabled = ref(false)
|
const openaiPassthroughEnabled = ref(false)
|
||||||
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
const openAICompactMode = ref<OpenAICompactMode>('auto')
|
||||||
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
|
const openAIResponsesMode = ref<OpenAIResponsesMode>('auto')
|
||||||
|
const openAIEndpointCapabilities = ref<OpenAIEndpointCapability[]>(['chat_completions', 'embeddings'])
|
||||||
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiOAuthResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
const openaiAPIKeyResponsesWebSocketV2Mode = ref<OpenAIWSMode>(OPENAI_WS_MODE_OFF)
|
||||||
const codexCLIOnlyEnabled = ref(false)
|
const codexCLIOnlyEnabled = ref(false)
|
||||||
|
const codexCLIOnlyAllowClaudeCodeEnabled = ref(false)
|
||||||
type CodexImageGenerationBridgeMode = 'inherit' | 'enabled' | 'disabled'
|
type CodexImageGenerationBridgeMode = 'inherit' | 'enabled' | 'disabled'
|
||||||
const codexImageGenerationBridgeMode = ref<CodexImageGenerationBridgeMode>('inherit')
|
const codexImageGenerationBridgeMode = ref<CodexImageGenerationBridgeMode>('inherit')
|
||||||
const anthropicPassthroughEnabled = ref(false)
|
const anthropicPassthroughEnabled = ref(false)
|
||||||
@ -2539,6 +2688,85 @@ const openAIResponsesModeOptions = computed(() => [
|
|||||||
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
|
{ value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
|
||||||
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
|
{ value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
|
||||||
])
|
])
|
||||||
|
const openAITextEndpointCapabilityLabel = computed(() => {
|
||||||
|
if (openAIResponsesMode.value === 'force_responses') {
|
||||||
|
return t('admin.accounts.openai.capabilityResponses')
|
||||||
|
}
|
||||||
|
if (openAIResponsesMode.value === 'force_chat_completions') {
|
||||||
|
return t('admin.accounts.openai.capabilityChatCompletions')
|
||||||
|
}
|
||||||
|
const extra = props.account?.extra as Record<string, unknown> | undefined
|
||||||
|
if (extra?.openai_responses_supported === true) {
|
||||||
|
return t('admin.accounts.openai.capabilityResponsesAuto')
|
||||||
|
}
|
||||||
|
if (extra?.openai_responses_supported === false) {
|
||||||
|
return t('admin.accounts.openai.capabilityChatCompletionsAuto')
|
||||||
|
}
|
||||||
|
return t('admin.accounts.openai.capabilityTextAuto')
|
||||||
|
})
|
||||||
|
const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [
|
||||||
|
{ value: 'chat_completions', label: openAITextEndpointCapabilityLabel.value },
|
||||||
|
{ value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') }
|
||||||
|
])
|
||||||
|
const openAITextGenerationCapabilityEnabled = computed(() =>
|
||||||
|
openAIEndpointCapabilities.value.includes('chat_completions')
|
||||||
|
)
|
||||||
|
|
||||||
|
const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => {
|
||||||
|
const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings']
|
||||||
|
const selected = allowed.filter((value) => values.includes(value))
|
||||||
|
return selected.length > 0 ? selected : allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
const readOpenAIEndpointCapabilities = (credentials?: Record<string, unknown>): OpenAIEndpointCapability[] => {
|
||||||
|
const raw = credentials?.openai_capabilities
|
||||||
|
if (Array.isArray(raw)) {
|
||||||
|
return normalizeOpenAIEndpointCapabilities(
|
||||||
|
raw.filter((value): value is OpenAIEndpointCapability =>
|
||||||
|
value === 'chat_completions' || value === 'embeddings'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (raw !== null && typeof raw === 'object') {
|
||||||
|
const capabilityMap = raw as Record<string, unknown>
|
||||||
|
return normalizeOpenAIEndpointCapabilities(
|
||||||
|
openAIEndpointCapabilityOptions.value
|
||||||
|
.map((option) => option.value)
|
||||||
|
.filter((value) => capabilityMap[value] === true)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return ['chat_completions', 'embeddings']
|
||||||
|
}
|
||||||
|
|
||||||
|
const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => {
|
||||||
|
if (openAIEndpointCapabilities.value.includes(capability)) {
|
||||||
|
if (openAIEndpointCapabilities.value.length <= 1) {
|
||||||
|
const input = event?.target as HTMLInputElement | null
|
||||||
|
if (input) input.checked = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter(
|
||||||
|
(value) => value !== capability
|
||||||
|
)
|
||||||
|
if (!openAITextGenerationCapabilityEnabled.value) {
|
||||||
|
openAIResponsesMode.value = 'auto'
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([
|
||||||
|
...openAIEndpointCapabilities.value,
|
||||||
|
capability
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
const applyOpenAIEndpointCapabilities = (credentials: Record<string, unknown>) => {
|
||||||
|
const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value)
|
||||||
|
if (capabilities.length === 2) {
|
||||||
|
delete credentials.openai_capabilities
|
||||||
|
return
|
||||||
|
}
|
||||||
|
credentials.openai_capabilities = capabilities
|
||||||
|
}
|
||||||
const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => {
|
const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => {
|
||||||
if (mode === 'force_responses' || mode === 'force_chat_completions') {
|
if (mode === 'force_responses' || mode === 'force_chat_completions') {
|
||||||
return mode
|
return mode
|
||||||
@ -2716,18 +2944,24 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
|||||||
// Load mixed scheduling setting (only for antigravity accounts)
|
// Load mixed scheduling setting (only for antigravity accounts)
|
||||||
mixedScheduling.value = false
|
mixedScheduling.value = false
|
||||||
allowOverages.value = false
|
allowOverages.value = false
|
||||||
const extra = newAccount.extra as Record<string, unknown> | undefined
|
const extra = newAccount.extra as Record<string, unknown> | undefined
|
||||||
mixedScheduling.value = extra?.mixed_scheduling === true
|
mixedScheduling.value = extra?.mixed_scheduling === true
|
||||||
allowOverages.value = extra?.allow_overages === true
|
allowOverages.value = extra?.allow_overages === true
|
||||||
|
autoPause5hThreshold.value = typeof extra?.auto_pause_5h_threshold === 'number' ? extra.auto_pause_5h_threshold * 100 : null
|
||||||
|
autoPause7dThreshold.value = typeof extra?.auto_pause_7d_threshold === 'number' ? extra.auto_pause_7d_threshold * 100 : null
|
||||||
|
autoPause5hDisabled.value = extra?.auto_pause_5h_disabled === true
|
||||||
|
autoPause7dDisabled.value = extra?.auto_pause_7d_disabled === true
|
||||||
|
|
||||||
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
|
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
|
||||||
openaiPassthroughEnabled.value = false
|
openaiPassthroughEnabled.value = false
|
||||||
openAICompactMode.value = 'auto'
|
openAICompactMode.value = 'auto'
|
||||||
openAIResponsesMode.value = 'auto'
|
openAIResponsesMode.value = 'auto'
|
||||||
|
openAIEndpointCapabilities.value = ['chat_completions', 'embeddings']
|
||||||
openAICompactModelMappings.value = []
|
openAICompactModelMappings.value = []
|
||||||
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
|
||||||
codexCLIOnlyEnabled.value = false
|
codexCLIOnlyEnabled.value = false
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value = false
|
||||||
codexImageGenerationBridgeMode.value = 'inherit'
|
codexImageGenerationBridgeMode.value = 'inherit'
|
||||||
anthropicPassthroughEnabled.value = false
|
anthropicPassthroughEnabled.value = false
|
||||||
webSearchEmulationMode.value = 'default'
|
webSearchEmulationMode.value = 'default'
|
||||||
@ -2736,6 +2970,12 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
|||||||
openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
|
openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
|
||||||
if (newAccount.type === 'apikey') {
|
if (newAccount.type === 'apikey') {
|
||||||
openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode)
|
openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode)
|
||||||
|
openAIEndpointCapabilities.value = readOpenAIEndpointCapabilities(
|
||||||
|
newAccount.credentials as Record<string, unknown> | undefined
|
||||||
|
)
|
||||||
|
if (!openAITextGenerationCapabilityEnabled.value) {
|
||||||
|
openAIResponsesMode.value = 'auto'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean'
|
const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean'
|
||||||
? extra.codex_image_generation_bridge
|
? extra.codex_image_generation_bridge
|
||||||
@ -2759,6 +2999,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
|
|||||||
})
|
})
|
||||||
if (newAccount.type === 'oauth') {
|
if (newAccount.type === 'oauth') {
|
||||||
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
|
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
|
||||||
|
codexCLIOnlyAllowClaudeCodeEnabled.value =
|
||||||
|
Array.isArray(extra?.codex_cli_only_allowed_clients) &&
|
||||||
|
(extra.codex_cli_only_allowed_clients as unknown[]).includes('claude_code')
|
||||||
}
|
}
|
||||||
const credentials = newAccount.credentials as Record<string, unknown> | undefined
|
const credentials = newAccount.credentials as Record<string, unknown> | undefined
|
||||||
const compactMappings = credentials?.compact_model_mapping as Record<string, string> | undefined
|
const compactMappings = credentials?.compact_model_mapping as Record<string, string> | undefined
|
||||||
@ -3476,6 +3719,7 @@ const handleSubmit = async () => {
|
|||||||
newCredentials.model_mapping = currentCredentials.model_mapping
|
newCredentials.model_mapping = currentCredentials.model_mapping
|
||||||
}
|
}
|
||||||
if (props.account.platform === 'openai') {
|
if (props.account.platform === 'openai') {
|
||||||
|
applyOpenAIEndpointCapabilities(newCredentials)
|
||||||
const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
|
const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
|
||||||
if (compactModelMapping) {
|
if (compactModelMapping) {
|
||||||
newCredentials.compact_model_mapping = compactModelMapping
|
newCredentials.compact_model_mapping = compactModelMapping
|
||||||
@ -3829,9 +4073,9 @@ const handleSubmit = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For OpenAI OAuth/API Key accounts, handle passthrough mode in extra
|
// For OpenAI OAuth/API Key accounts, handle passthrough mode in extra
|
||||||
if (props.account.platform === 'openai' && (props.account.type === 'oauth' || props.account.type === 'apikey')) {
|
if (props.account.platform === 'openai' && (props.account.type === 'oauth' || props.account.type === 'apikey')) {
|
||||||
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
const currentExtra = (props.account.extra as Record<string, unknown>) || {}
|
||||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||||
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
|
const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true
|
||||||
if (props.account.type === 'oauth') {
|
if (props.account.type === 'oauth') {
|
||||||
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value
|
||||||
@ -3853,15 +4097,35 @@ const handleSubmit = async () => {
|
|||||||
} else {
|
} else {
|
||||||
newExtra.openai_compact_mode = openAICompactMode.value
|
newExtra.openai_compact_mode = openAICompactMode.value
|
||||||
}
|
}
|
||||||
if (props.account.type === 'apikey') {
|
if (props.account.type === 'apikey') {
|
||||||
if (openAIResponsesMode.value === 'auto') {
|
if (!openAITextGenerationCapabilityEnabled.value || openAIResponsesMode.value === 'auto') {
|
||||||
delete newExtra.openai_responses_mode
|
delete newExtra.openai_responses_mode
|
||||||
} else {
|
} else {
|
||||||
newExtra.openai_responses_mode = openAIResponsesMode.value
|
newExtra.openai_responses_mode = openAIResponsesMode.value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (autoPause5hThreshold.value != null && autoPause5hThreshold.value > 0) {
|
||||||
|
newExtra.auto_pause_5h_threshold = autoPause5hThreshold.value / 100
|
||||||
|
} else {
|
||||||
|
delete newExtra.auto_pause_5h_threshold
|
||||||
|
}
|
||||||
|
if (autoPause7dThreshold.value != null && autoPause7dThreshold.value > 0) {
|
||||||
|
newExtra.auto_pause_7d_threshold = autoPause7dThreshold.value / 100
|
||||||
|
} else {
|
||||||
|
delete newExtra.auto_pause_7d_threshold
|
||||||
|
}
|
||||||
|
if (autoPause5hDisabled.value) {
|
||||||
|
newExtra.auto_pause_5h_disabled = true
|
||||||
|
} else {
|
||||||
|
delete newExtra.auto_pause_5h_disabled
|
||||||
|
}
|
||||||
|
if (autoPause7dDisabled.value) {
|
||||||
|
newExtra.auto_pause_7d_disabled = true
|
||||||
|
} else {
|
||||||
|
delete newExtra.auto_pause_7d_disabled
|
||||||
|
}
|
||||||
|
|
||||||
delete newExtra.codex_image_generation_bridge_enabled
|
delete newExtra.codex_image_generation_bridge_enabled
|
||||||
if (codexImageGenerationBridgeMode.value === 'inherit') {
|
if (codexImageGenerationBridgeMode.value === 'inherit') {
|
||||||
delete newExtra.codex_image_generation_bridge
|
delete newExtra.codex_image_generation_bridge
|
||||||
} else {
|
} else {
|
||||||
@ -3877,6 +4141,12 @@ const handleSubmit = async () => {
|
|||||||
} else {
|
} else {
|
||||||
delete newExtra.codex_cli_only
|
delete newExtra.codex_cli_only
|
||||||
}
|
}
|
||||||
|
// 仅当 codex_cli_only 开启且子开关开启时写入 Claude Code 插件白名单,否则清除避免孤立字段
|
||||||
|
if (codexCLIOnlyEnabled.value && codexCLIOnlyAllowClaudeCodeEnabled.value) {
|
||||||
|
newExtra.codex_cli_only_allowed_clients = ['claude_code']
|
||||||
|
} else {
|
||||||
|
delete newExtra.codex_cli_only_allowed_clients
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
|
|||||||
@ -197,6 +197,25 @@ describe('BulkEditAccountModal', () => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('OpenAI OAuth 批量编辑应提交 codex_cli_only_allowed_clients 字段', async () => {
|
||||||
|
const wrapper = mountModal({
|
||||||
|
selectedPlatforms: ['openai'],
|
||||||
|
selectedTypes: ['oauth']
|
||||||
|
})
|
||||||
|
|
||||||
|
await wrapper.get('#bulk-edit-openai-codex-allow-claude-code-enabled').setValue(true)
|
||||||
|
await wrapper.get('#bulk-edit-openai-codex-allow-claude-code-toggle').trigger('click')
|
||||||
|
await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
|
||||||
|
expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
|
||||||
|
extra: {
|
||||||
|
codex_cli_only_allowed_clients: ['claude_code']
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
|
it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
|
||||||
const wrapper = mountModal({
|
const wrapper = mountModal({
|
||||||
selectedPlatforms: ['openai'],
|
selectedPlatforms: ['openai'],
|
||||||
|
|||||||
@ -310,6 +310,137 @@ describe('EditAccountModal', () => {
|
|||||||
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('submits OpenAI APIKey endpoint capabilities from credentials', async () => {
|
||||||
|
const account = buildAccount()
|
||||||
|
account.credentials.openai_capabilities = ['chat_completions']
|
||||||
|
updateAccountMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
|
||||||
|
updateAccountMock.mockResolvedValue(account)
|
||||||
|
|
||||||
|
const wrapper = mountModal(account)
|
||||||
|
|
||||||
|
expect(wrapper.findAll('input[type="checkbox"]').some((input) => (input.element as HTMLInputElement).checked)).toBe(true)
|
||||||
|
|
||||||
|
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
expect(updateAccountMock).toHaveBeenCalledTimes(1)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
|
||||||
|
'chat_completions'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('submits OpenAI quota auto-pause thresholds in extra', async () => {
|
||||||
|
const account = buildAccount()
|
||||||
|
account.extra = {
|
||||||
|
auto_pause_5h_threshold: 0.9,
|
||||||
|
auto_pause_7d_threshold: 0.8
|
||||||
|
}
|
||||||
|
updateAccountMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
|
||||||
|
updateAccountMock.mockResolvedValue(account)
|
||||||
|
|
||||||
|
const wrapper = mountModal(account)
|
||||||
|
|
||||||
|
await wrapper.get('[data-testid="auto-pause-5h-threshold"]').setValue('95')
|
||||||
|
await wrapper.get('[data-testid="auto-pause-7d-threshold"]').setValue('96')
|
||||||
|
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
expect(updateAccountMock).toHaveBeenCalledTimes(1)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_5h_threshold).toBe(0.95)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_7d_threshold).toBe(0.96)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('submits OpenAI quota auto-pause disable flag in extra', async () => {
|
||||||
|
// Toggling the per-account disable flag must persist as auto_pause_5h_disabled
|
||||||
|
// so an admin can exempt one account from auto-pause even when a global default
|
||||||
|
// threshold is configured (otherwise leaving the threshold blank would silently
|
||||||
|
// fall back to the global default).
|
||||||
|
const account = buildAccount()
|
||||||
|
updateAccountMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
|
||||||
|
updateAccountMock.mockResolvedValue(account)
|
||||||
|
|
||||||
|
const wrapper = mountModal(account)
|
||||||
|
|
||||||
|
await wrapper.get('[data-testid="auto-pause-5h-disabled"]').trigger('click')
|
||||||
|
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
expect(updateAccountMock).toHaveBeenCalledTimes(1)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_5h_disabled).toBe(true)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.auto_pause_7d_disabled).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps at least one OpenAI APIKey endpoint capability selected', async () => {
|
||||||
|
const account = buildAccount()
|
||||||
|
updateAccountMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
|
||||||
|
updateAccountMock.mockResolvedValue(account)
|
||||||
|
|
||||||
|
const wrapper = mountModal(account)
|
||||||
|
|
||||||
|
const chatCheckbox = wrapper.get<HTMLInputElement>(
|
||||||
|
'[data-testid="openai-endpoint-capability-chat_completions"]'
|
||||||
|
)
|
||||||
|
const embeddingsCheckbox = wrapper.get<HTMLInputElement>(
|
||||||
|
'[data-testid="openai-endpoint-capability-embeddings"]'
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(chatCheckbox.element.checked).toBe(true)
|
||||||
|
expect(embeddingsCheckbox.element.checked).toBe(true)
|
||||||
|
|
||||||
|
await embeddingsCheckbox.setValue(false)
|
||||||
|
|
||||||
|
expect(chatCheckbox.element.checked).toBe(true)
|
||||||
|
expect(embeddingsCheckbox.element.checked).toBe(false)
|
||||||
|
|
||||||
|
await chatCheckbox.setValue(false)
|
||||||
|
|
||||||
|
expect(chatCheckbox.element.checked).toBe(true)
|
||||||
|
expect(embeddingsCheckbox.element.checked).toBe(false)
|
||||||
|
|
||||||
|
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
expect(updateAccountMock).toHaveBeenCalledTimes(1)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
|
||||||
|
'chat_completions'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('disables text generation protocol when only embeddings requests are accepted', async () => {
|
||||||
|
const account = buildAccount()
|
||||||
|
account.credentials.openai_capabilities = ['embeddings']
|
||||||
|
account.extra = {
|
||||||
|
openai_responses_mode: 'force_responses',
|
||||||
|
openai_responses_supported: true
|
||||||
|
}
|
||||||
|
updateAccountMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockReset()
|
||||||
|
checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
|
||||||
|
updateAccountMock.mockResolvedValue(account)
|
||||||
|
|
||||||
|
const wrapper = mountModal(account)
|
||||||
|
|
||||||
|
const responsesModeSelect = wrapper.get<HTMLSelectElement>(
|
||||||
|
'[data-testid="openai-responses-mode-select"]'
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(responsesModeSelect.element.disabled).toBe(true)
|
||||||
|
expect(wrapper.find('[data-testid="openai-responses-mode-not-applicable"]').exists()).toBe(true)
|
||||||
|
|
||||||
|
await wrapper.get('form#edit-account-form').trigger('submit.prevent')
|
||||||
|
|
||||||
|
expect(updateAccountMock).toHaveBeenCalledTimes(1)
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([
|
||||||
|
'embeddings'
|
||||||
|
])
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra).not.toHaveProperty('openai_responses_mode')
|
||||||
|
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
it('submits account-level Codex image generation bridge override', async () => {
|
it('submits account-level Codex image generation bridge override', async () => {
|
||||||
const account = buildAccount()
|
const account = buildAccount()
|
||||||
account.extra = {
|
account.extra = {
|
||||||
|
|||||||
@ -35,6 +35,11 @@ describe('useModelWhitelist', () => {
|
|||||||
expect(models).toContain('gemini-3-pro-image')
|
expect(models).toContain('gemini-3-pro-image')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('Claude 模型列表包含 Opus 4.8', () => {
|
||||||
|
expect(getModelsByPlatform('claude')).toContain('claude-opus-4-8')
|
||||||
|
expect(getModelsByPlatform('antigravity')).toContain('claude-opus-4-8')
|
||||||
|
})
|
||||||
|
|
||||||
it('gemini 模型列表包含原生生图模型', () => {
|
it('gemini 模型列表包含原生生图模型', () => {
|
||||||
const models = getModelsByPlatform('gemini')
|
const models = getModelsByPlatform('gemini')
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ export const claudeModels = [
|
|||||||
'claude-opus-4-5-20251101',
|
'claude-opus-4-5-20251101',
|
||||||
'claude-opus-4-6',
|
'claude-opus-4-6',
|
||||||
'claude-opus-4-7',
|
'claude-opus-4-7',
|
||||||
|
'claude-opus-4-8',
|
||||||
'claude-sonnet-4-6'
|
'claude-sonnet-4-6'
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -53,6 +54,7 @@ const antigravityModels = [
|
|||||||
'claude-opus-4-6',
|
'claude-opus-4-6',
|
||||||
'claude-opus-4-6-thinking',
|
'claude-opus-4-6-thinking',
|
||||||
'claude-opus-4-7',
|
'claude-opus-4-7',
|
||||||
|
'claude-opus-4-8',
|
||||||
'claude-opus-4-5-thinking',
|
'claude-opus-4-5-thinking',
|
||||||
'claude-sonnet-4-6',
|
'claude-sonnet-4-6',
|
||||||
'claude-sonnet-4-5',
|
'claude-sonnet-4-5',
|
||||||
@ -263,6 +265,7 @@ const anthropicPresetMappings = [
|
|||||||
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
{ label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
|
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'claude-opus-4-8', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
|
||||||
{ label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
{ label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' },
|
||||||
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4-5-20251001', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
{ label: 'Haiku 4.5', from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4-5-20251001', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' },
|
||||||
{ label: 'Opus->Sonnet', from: 'claude-opus-4-6', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
{ label: 'Opus->Sonnet', from: 'claude-opus-4-6', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }
|
||||||
@ -322,7 +325,8 @@ const antigravityPresetMappings = [
|
|||||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
|
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'claude-opus-4-8', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Windsurf 预设映射
|
// Windsurf 预设映射
|
||||||
@ -332,6 +336,7 @@ const windsurfPresetMappings: { label: string; from: string; to: string; color:
|
|||||||
const bedrockPresetMappings = [
|
const bedrockPresetMappings = [
|
||||||
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
{ label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'us.anthropic.claude-opus-4-7-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
{ label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'us.anthropic.claude-opus-4-7-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
|
{ label: 'Opus 4.8', from: 'claude-opus-4-8', to: 'us.anthropic.claude-opus-4-8-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
{ label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
|
||||||
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
|
||||||
|
|||||||
@ -3361,10 +3361,21 @@ export default {
|
|||||||
'Automatic passthrough is currently enabled: it only affects HTTP passthrough and does not disable WS mode.',
|
'Automatic passthrough is currently enabled: it only affects HTTP passthrough and does not disable WS mode.',
|
||||||
responsesMode: 'Responses API support',
|
responsesMode: 'Responses API support',
|
||||||
responsesModeDesc:
|
responsesModeDesc:
|
||||||
'Only applies to OpenAI API Key accounts. Auto follows probe results; force modes override probing.',
|
'Only applies to the OpenAI API Key text forwarding path. Auto follows probe results; force modes override probing.',
|
||||||
responsesModeAuto: 'Auto',
|
responsesModeAuto: 'Auto',
|
||||||
responsesModeForceResponses: 'Force Responses',
|
responsesModeForceResponses: 'Force Responses',
|
||||||
responsesModeForceChatCompletions: 'Force Chat Completions',
|
responsesModeForceChatCompletions: 'Force Chat Completions',
|
||||||
|
responsesModeTextDisabledHint:
|
||||||
|
'Not applicable when the Responses / Chat Completions endpoint is not enabled.',
|
||||||
|
endpointCapabilities: 'Endpoint capabilities',
|
||||||
|
endpointCapabilitiesDesc:
|
||||||
|
'Used by account routing. The text endpoint follows the Responses API support setting above and is shown as Responses, Chat Completions, or auto mode; Embeddings independently controls /v1/embeddings.',
|
||||||
|
capabilityResponses: 'Responses',
|
||||||
|
capabilityTextAuto: 'Responses / Chat Completions (Auto)',
|
||||||
|
capabilityResponsesAuto: 'Responses (auto probe)',
|
||||||
|
capabilityChatCompletions: 'Chat Completions',
|
||||||
|
capabilityChatCompletionsAuto: 'Chat Completions (auto probe)',
|
||||||
|
capabilityEmbeddings: 'Embeddings',
|
||||||
responsesStatusAutoSupported: 'Auto probe: Responses',
|
responsesStatusAutoSupported: 'Auto probe: Responses',
|
||||||
responsesStatusAutoUnsupported: 'Auto probe: Chat Completions',
|
responsesStatusAutoUnsupported: 'Auto probe: Chat Completions',
|
||||||
responsesStatusAutoUnknown: 'Auto probe: unknown',
|
responsesStatusAutoUnknown: 'Auto probe: unknown',
|
||||||
@ -3373,6 +3384,9 @@ export default {
|
|||||||
codexCLIOnly: 'Codex official clients only',
|
codexCLIOnly: 'Codex official clients only',
|
||||||
codexCLIOnlyDesc:
|
codexCLIOnlyDesc:
|
||||||
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
|
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
|
||||||
|
codexCLIOnlyAllowClaudeCode: "Also allow Claude Code's Codex plugin",
|
||||||
|
codexCLIOnlyAllowClaudeCodeDesc:
|
||||||
|
'Only takes effect when the switch above is on. Additionally allows requests from the Claude Code Codex plugin (exact match on originator=Claude Code) without weakening blocking of other non-official clients.',
|
||||||
codexImageGenerationBridge: 'Codex image-generation bridge',
|
codexImageGenerationBridge: 'Codex image-generation bridge',
|
||||||
codexImageGenerationBridgeDesc:
|
codexImageGenerationBridgeDesc:
|
||||||
'Account policy takes precedence over channel and global settings. Only controls whether Codex requests through the /responses text endpoint receive the image_generation tool; standalone image-generation endpoints are unaffected.',
|
'Account policy takes precedence over channel and global settings. Only controls whether Codex requests through the /responses text endpoint receive the image_generation tool; standalone image-generation endpoints are unaffected.',
|
||||||
@ -3473,6 +3487,12 @@ export default {
|
|||||||
'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
|
'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens',
|
||||||
autoPauseOnExpired: 'Auto Pause On Expired',
|
autoPauseOnExpired: 'Auto Pause On Expired',
|
||||||
autoPauseOnExpiredDesc: 'When enabled, the account will auto pause scheduling after it expires',
|
autoPauseOnExpiredDesc: 'When enabled, the account will auto pause scheduling after it expires',
|
||||||
|
autoPause5hThreshold: '5h Usage Threshold (%)',
|
||||||
|
autoPause7dThreshold: '7d Usage Threshold (%)',
|
||||||
|
autoPauseThresholdHint: 'Leave empty or set 0 to use the global default threshold (configured in Ops settings); set a value to override the global default. Reaching the threshold only skips the account during scheduling and does not modify schedulable.',
|
||||||
|
autoPause5hDisabled: 'Disable 5h auto-pause',
|
||||||
|
autoPause7dDisabled: 'Disable 7d auto-pause',
|
||||||
|
autoPauseDisabledHint: 'When enabled, this account is never auto-paused (even if a global default threshold is configured).',
|
||||||
// Quota control (Anthropic OAuth/SetupToken only)
|
// Quota control (Anthropic OAuth/SetupToken only)
|
||||||
quotaControl: {
|
quotaControl: {
|
||||||
title: 'Quota Control',
|
title: 'Quota Control',
|
||||||
@ -5203,6 +5223,11 @@ export default {
|
|||||||
aggregation: 'Pre-aggregation Tasks',
|
aggregation: 'Pre-aggregation Tasks',
|
||||||
enableAggregation: 'Enable Pre-aggregation',
|
enableAggregation: 'Enable Pre-aggregation',
|
||||||
aggregationHint: 'Pre-aggregation improves query performance for long time windows',
|
aggregationHint: 'Pre-aggregation improves query performance for long time windows',
|
||||||
|
openaiQuotaAutoPause: 'OpenAI Account Quota Auto-pause',
|
||||||
|
openaiQuotaAutoPauseHint: 'When an OpenAI account reaches its 5h / 7d usage threshold, the scheduler skips it automatically and resumes once the window rolls over. Per-account thresholds take precedence over this global default.',
|
||||||
|
openaiQuotaAutoPauseDefault5h: 'Default 5h usage threshold (%)',
|
||||||
|
openaiQuotaAutoPauseDefault7d: 'Default 7d usage threshold (%)',
|
||||||
|
openaiQuotaAutoPauseThresholdHint: 'Value 0-100; leave blank or 0 to disable the global default threshold.',
|
||||||
errorFiltering: 'Error Filtering',
|
errorFiltering: 'Error Filtering',
|
||||||
ignoreCountTokensErrors: 'Ignore count_tokens errors',
|
ignoreCountTokensErrors: 'Ignore count_tokens errors',
|
||||||
ignoreCountTokensErrorsHint: 'When enabled, errors from count_tokens requests will not be written to the error log.',
|
ignoreCountTokensErrorsHint: 'When enabled, errors from count_tokens requests will not be written to the error log.',
|
||||||
@ -5233,7 +5258,8 @@ export default {
|
|||||||
slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100',
|
slaMinPercentRange: 'SLA minimum percentage must be between 0 and 100',
|
||||||
ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0',
|
ttftP99MaxRange: 'TTFT P99 maximum must be a number ≥ 0',
|
||||||
requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100',
|
requestErrorRateMaxRange: 'Request error rate maximum must be between 0 and 100',
|
||||||
upstreamErrorRateMaxRange: 'Upstream error rate maximum must be between 0 and 100'
|
upstreamErrorRateMaxRange: 'Upstream error rate maximum must be between 0 and 100',
|
||||||
|
openaiQuotaAutoPauseRange: 'OpenAI quota auto-pause threshold must be between 0 and 100'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
concurrency: {
|
concurrency: {
|
||||||
@ -5627,6 +5653,9 @@ export default {
|
|||||||
openaiCodexUserAgent: 'OpenAI Codex UA',
|
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||||
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||||
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
|
openaiCodexUserAgentHint: 'Used to bypass Cloudflare browser-UA challenges on the OpenAI upstream. Only applies when the client User-Agent is detected as a browser (Mozilla/...). Leave empty to use the built-in default.',
|
||||||
|
openaiAllowClaudeCodeCodexPlugin: "Allow using the Codex plugin in Claude Code",
|
||||||
|
openaiAllowClaudeCodeCodexPluginDesc:
|
||||||
|
"Global switch; only affects OpenAI OAuth accounts that have 'Codex official clients only' enabled. When on, all such accounts additionally allow requests from the Claude Code Codex plugin (exact match on originator=Claude Code) without per-account config; upstream requests remain pass-through.",
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search Emulation',
|
title: 'Web Search Emulation',
|
||||||
|
|||||||
@ -3507,10 +3507,20 @@ export default {
|
|||||||
responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
|
responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
|
||||||
responsesMode: 'Responses API 支持',
|
responsesMode: 'Responses API 支持',
|
||||||
responsesModeDesc:
|
responsesModeDesc:
|
||||||
'仅对 OpenAI API Key 生效。自动跟随探测结果,强制模式会覆盖自动探测。',
|
'仅对 OpenAI API Key 的文本转发链路生效。自动跟随探测结果,强制模式会覆盖自动探测。',
|
||||||
responsesModeAuto: '自动',
|
responsesModeAuto: '自动',
|
||||||
responsesModeForceResponses: '强制 Responses',
|
responsesModeForceResponses: '强制 Responses',
|
||||||
responsesModeForceChatCompletions: '强制 Chat Completions',
|
responsesModeForceChatCompletions: '强制 Chat Completions',
|
||||||
|
responsesModeTextDisabledHint: '未启用 Responses / Chat Completions 端点时,此设置不适用。',
|
||||||
|
endpointCapabilities: '端点能力',
|
||||||
|
endpointCapabilitiesDesc:
|
||||||
|
'用于调度筛选。文本端点会跟随上方 Responses API 支持显示为 Responses、Chat Completions 或自动模式;Embeddings 独立控制 /v1/embeddings。',
|
||||||
|
capabilityResponses: 'Responses',
|
||||||
|
capabilityTextAuto: 'Responses / Chat Completions(自动)',
|
||||||
|
capabilityResponsesAuto: 'Responses(自动探测)',
|
||||||
|
capabilityChatCompletions: 'Chat Completions',
|
||||||
|
capabilityChatCompletionsAuto: 'Chat Completions(自动探测)',
|
||||||
|
capabilityEmbeddings: 'Embeddings',
|
||||||
responsesStatusAutoSupported: '自动探测:Responses',
|
responsesStatusAutoSupported: '自动探测:Responses',
|
||||||
responsesStatusAutoUnsupported: '自动探测:Chat Completions',
|
responsesStatusAutoUnsupported: '自动探测:Chat Completions',
|
||||||
responsesStatusAutoUnknown: '自动探测:未探测',
|
responsesStatusAutoUnknown: '自动探测:未探测',
|
||||||
@ -3518,6 +3528,8 @@ export default {
|
|||||||
responsesStatusForcedChatCompletions: '已强制 Chat Completions',
|
responsesStatusForcedChatCompletions: '已强制 Chat Completions',
|
||||||
codexCLIOnly: '仅允许 Codex 官方客户端',
|
codexCLIOnly: '仅允许 Codex 官方客户端',
|
||||||
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
|
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
|
||||||
|
codexCLIOnlyAllowClaudeCode: '额外放行 Claude Code 的 Codex 插件',
|
||||||
|
codexCLIOnlyAllowClaudeCodeDesc: '仅在上方开关开启时生效。额外放行通过 Claude Code 的 Codex 插件发起的请求(精确匹配 originator=Claude Code),不影响对其他非官方客户端的拦截。',
|
||||||
codexImageGenerationBridge: 'Codex 图片生成桥接',
|
codexImageGenerationBridge: 'Codex 图片生成桥接',
|
||||||
codexImageGenerationBridgeDesc:
|
codexImageGenerationBridgeDesc:
|
||||||
'账号级策略优先于渠道和全局配置。仅控制 Codex 走 /responses 文本端点时是否注入 image_generation 工具;不影响独立图片生成接口。',
|
'账号级策略优先于渠道和全局配置。仅控制 Codex 走 /responses 文本端点时是否注入 image_generation 工具;不影响独立图片生成接口。',
|
||||||
@ -3613,6 +3625,12 @@ export default {
|
|||||||
interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token',
|
interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token',
|
||||||
autoPauseOnExpired: '过期自动暂停调度',
|
autoPauseOnExpired: '过期自动暂停调度',
|
||||||
autoPauseOnExpiredDesc: '启用后,账号过期将自动暂停调度',
|
autoPauseOnExpiredDesc: '启用后,账号过期将自动暂停调度',
|
||||||
|
autoPause5hThreshold: '5h 用量阈值(%)',
|
||||||
|
autoPause7dThreshold: '7d 用量阈值(%)',
|
||||||
|
autoPauseThresholdHint: '留空或填 0 表示使用全局默认阈值(在运维设置中配置);填具体值则覆盖全局默认。达到阈值后仅在调度时跳过账号,不修改 schedulable。',
|
||||||
|
autoPause5hDisabled: '禁用 5h 自动暂停',
|
||||||
|
autoPause7dDisabled: '禁用 7d 自动暂停',
|
||||||
|
autoPauseDisabledHint: '开启后该账号永不进入自动暂停(即使全局默认阈值已配置)。',
|
||||||
// Quota control (Anthropic OAuth/SetupToken only)
|
// Quota control (Anthropic OAuth/SetupToken only)
|
||||||
quotaControl: {
|
quotaControl: {
|
||||||
title: '配额控制',
|
title: '配额控制',
|
||||||
@ -5364,6 +5382,11 @@ export default {
|
|||||||
aggregation: '预聚合任务',
|
aggregation: '预聚合任务',
|
||||||
enableAggregation: '启用预聚合任务',
|
enableAggregation: '启用预聚合任务',
|
||||||
aggregationHint: '预聚合可提升长时间窗口查询性能',
|
aggregationHint: '预聚合可提升长时间窗口查询性能',
|
||||||
|
openaiQuotaAutoPause: 'OpenAI 账号配额自动暂停',
|
||||||
|
openaiQuotaAutoPauseHint: '当 OpenAI 账号 5h / 7d 用量达到阈值时,调度会自动跳过该账号;窗口滚动后自动恢复。账号级阈值优先于此全局默认值。',
|
||||||
|
openaiQuotaAutoPauseDefault5h: '默认 5h 用量阈值 (%)',
|
||||||
|
openaiQuotaAutoPauseDefault7d: '默认 7d 用量阈值 (%)',
|
||||||
|
openaiQuotaAutoPauseThresholdHint: '取值 0-100,留空或 0 表示不启用全局默认阈值。',
|
||||||
errorFiltering: '错误过滤',
|
errorFiltering: '错误过滤',
|
||||||
ignoreCountTokensErrors: '忽略 count_tokens 错误',
|
ignoreCountTokensErrors: '忽略 count_tokens 错误',
|
||||||
ignoreCountTokensErrorsHint: '启用后,count_tokens 请求的错误将不会写入错误日志。',
|
ignoreCountTokensErrorsHint: '启用后,count_tokens 请求的错误将不会写入错误日志。',
|
||||||
@ -5395,7 +5418,8 @@ export default {
|
|||||||
slaMinPercentRange: 'SLA最低百分比必须在0-100之间',
|
slaMinPercentRange: 'SLA最低百分比必须在0-100之间',
|
||||||
ttftP99MaxRange: 'TTFT P99最大值必须大于等于0',
|
ttftP99MaxRange: 'TTFT P99最大值必须大于等于0',
|
||||||
requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间',
|
requestErrorRateMaxRange: '请求错误率最大值必须在0-100之间',
|
||||||
upstreamErrorRateMaxRange: '上游错误率最大值必须在0-100之间'
|
upstreamErrorRateMaxRange: '上游错误率最大值必须在0-100之间',
|
||||||
|
openaiQuotaAutoPauseRange: 'OpenAI 配额自动暂停阈值必须在 0-100 之间'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
concurrency: {
|
concurrency: {
|
||||||
@ -5783,6 +5807,9 @@ export default {
|
|||||||
openaiCodexUserAgent: 'OpenAI Codex UA',
|
openaiCodexUserAgent: 'OpenAI Codex UA',
|
||||||
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
openaiCodexUserAgentPlaceholder: 'codex-tui/0.125.0 (Ubuntu 22.4.0; x86_64) xterm-256color (codex-tui; 0.125.0)',
|
||||||
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器(Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
|
openaiCodexUserAgentHint: '用于规避 OpenAI 上游 Cloudflare 对浏览器 UA 的访问质询。仅在检测到客户端 User-Agent 为浏览器(Mozilla/...)时生效,其他客户端原样透传。留空使用内置默认值。',
|
||||||
|
openaiAllowClaudeCodeCodexPlugin: '允许在 Claude Code 中使用 Codex 插件',
|
||||||
|
openaiAllowClaudeCodeCodexPluginDesc:
|
||||||
|
'全局开关,仅对已开启「仅允许 Codex 官方客户端」的 OpenAI OAuth 账号生效。开启后,所有此类账号都额外放行通过 Claude Code 的 Codex 插件发起的请求(精确匹配 originator=Claude Code),无需逐账号配置;上游请求仍保持透传。',
|
||||||
},
|
},
|
||||||
webSearchEmulation: {
|
webSearchEmulation: {
|
||||||
title: 'Web Search 模拟',
|
title: 'Web Search 模拟',
|
||||||
|
|||||||
@ -1185,6 +1185,7 @@ export interface CodexUsageSnapshot {
|
|||||||
|
|
||||||
export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off'
|
export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off'
|
||||||
export type OpenAIResponsesMode = 'auto' | 'force_responses' | 'force_chat_completions'
|
export type OpenAIResponsesMode = 'auto' | 'force_responses' | 'force_chat_completions'
|
||||||
|
export type OpenAIEndpointCapability = 'chat_completions' | 'embeddings'
|
||||||
|
|
||||||
export interface OpenAICompactState {
|
export interface OpenAICompactState {
|
||||||
openai_compact_mode?: OpenAICompactMode
|
openai_compact_mode?: OpenAICompactMode
|
||||||
|
|||||||
@ -3948,6 +3948,19 @@
|
|||||||
}}
|
}}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 是否允许在 Claude Code 中使用 Codex 插件(全局开关) -->
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<div class="pr-4">
|
||||||
|
<label class="block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t("admin.settings.gatewayForwarding.openaiAllowClaudeCodeCodexPlugin") }}
|
||||||
|
</label>
|
||||||
|
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||||
|
{{ t("admin.settings.gatewayForwarding.openaiAllowClaudeCodeCodexPluginDesc") }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Toggle v-model="form.openai_allow_claude_code_codex_plugin" />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<!-- Web Search Emulation -->
|
<!-- Web Search Emulation -->
|
||||||
@ -7163,6 +7176,7 @@ const form = reactive<SettingsForm>({
|
|||||||
rewrite_message_cache_control: false,
|
rewrite_message_cache_control: false,
|
||||||
antigravity_user_agent_version: "",
|
antigravity_user_agent_version: "",
|
||||||
openai_codex_user_agent: "",
|
openai_codex_user_agent: "",
|
||||||
|
openai_allow_claude_code_codex_plugin: false,
|
||||||
// 余额、订阅到期与账号限额通知
|
// 余额、订阅到期与账号限额通知
|
||||||
balance_low_notify_enabled: false,
|
balance_low_notify_enabled: false,
|
||||||
balance_low_notify_threshold: 0,
|
balance_low_notify_threshold: 0,
|
||||||
@ -8269,6 +8283,7 @@ async function saveSettings() {
|
|||||||
form.antigravity_user_agent_version?.trim() || "",
|
form.antigravity_user_agent_version?.trim() || "",
|
||||||
openai_codex_user_agent:
|
openai_codex_user_agent:
|
||||||
form.openai_codex_user_agent?.trim() || "",
|
form.openai_codex_user_agent?.trim() || "",
|
||||||
|
openai_allow_claude_code_codex_plugin: form.openai_allow_claude_code_codex_plugin,
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
payment_enabled: form.payment_enabled,
|
payment_enabled: form.payment_enabled,
|
||||||
risk_control_enabled: form.risk_control_enabled,
|
risk_control_enabled: form.risk_control_enabled,
|
||||||
|
|||||||
@ -50,6 +50,10 @@ async function loadAllSettings() {
|
|||||||
runtimeSettings.value = runtime
|
runtimeSettings.value = runtime
|
||||||
emailConfig.value = email
|
emailConfig.value = email
|
||||||
advancedSettings.value = advanced
|
advancedSettings.value = advanced
|
||||||
|
// 兼容旧 payload:后端未返回该字段时补默认值,保证表单可绑定
|
||||||
|
if (advancedSettings.value && !advancedSettings.value.openai_account_quota_auto_pause) {
|
||||||
|
advancedSettings.value.openai_account_quota_auto_pause = { default_threshold_5h: 0, default_threshold_7d: 0 }
|
||||||
|
}
|
||||||
// 如果后端返回了阈值,使用后端的值;否则保持默认值
|
// 如果后端返回了阈值,使用后端的值;否则保持默认值
|
||||||
if (thresholds && Object.keys(thresholds).length > 0) {
|
if (thresholds && Object.keys(thresholds).length > 0) {
|
||||||
metricThresholds.value = {
|
metricThresholds.value = {
|
||||||
@ -119,6 +123,28 @@ function removeRecipient(target: 'alert' | 'report', email: string) {
|
|||||||
if (idx >= 0) list.splice(idx, 1)
|
if (idx >= 0) list.splice(idx, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI 账号配额自动暂停:后端按 0~1 分数存储,UI 按百分比(0~100)展示
|
||||||
|
const quotaAutoPause5hPercent = computed<number | null>({
|
||||||
|
get() {
|
||||||
|
const v = advancedSettings.value?.openai_account_quota_auto_pause?.default_threshold_5h
|
||||||
|
return v && v > 0 ? Math.round(v * 1000) / 10 : null
|
||||||
|
},
|
||||||
|
set(val) {
|
||||||
|
if (!advancedSettings.value?.openai_account_quota_auto_pause) return
|
||||||
|
advancedSettings.value.openai_account_quota_auto_pause.default_threshold_5h = val != null && val > 0 ? val / 100 : 0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
const quotaAutoPause7dPercent = computed<number | null>({
|
||||||
|
get() {
|
||||||
|
const v = advancedSettings.value?.openai_account_quota_auto_pause?.default_threshold_7d
|
||||||
|
return v && v > 0 ? Math.round(v * 1000) / 10 : null
|
||||||
|
},
|
||||||
|
set(val) {
|
||||||
|
if (!advancedSettings.value?.openai_account_quota_auto_pause) return
|
||||||
|
advancedSettings.value.openai_account_quota_auto_pause.default_threshold_7d = val != null && val > 0 ? val / 100 : 0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// 验证
|
// 验证
|
||||||
const validation = computed(() => {
|
const validation = computed(() => {
|
||||||
const errors: string[] = []
|
const errors: string[] = []
|
||||||
@ -145,6 +171,11 @@ const validation = computed(() => {
|
|||||||
if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) {
|
if (hourly_metrics_retention_days < 0 || hourly_metrics_retention_days > 365) {
|
||||||
errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
|
errors.push(t('admin.ops.settings.validation.retentionDaysRange'))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const { default_threshold_5h, default_threshold_7d } = advancedSettings.value.openai_account_quota_auto_pause
|
||||||
|
if (default_threshold_5h < 0 || default_threshold_5h > 1 || default_threshold_7d < 0 || default_threshold_7d > 1) {
|
||||||
|
errors.push(t('admin.ops.settings.validation.openaiQuotaAutoPauseRange'))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证指标阈值
|
// 验证指标阈值
|
||||||
@ -473,6 +504,40 @@ async function saveAllSettings() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- OpenAI 账号配额自动暂停(全局默认阈值) -->
|
||||||
|
<div class="space-y-3">
|
||||||
|
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.openaiQuotaAutoPause') }}</h5>
|
||||||
|
<p class="text-xs text-gray-500">{{ t('admin.ops.settings.openaiQuotaAutoPauseHint') }}</p>
|
||||||
|
|
||||||
|
<div class="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.ops.settings.openaiQuotaAutoPauseDefault5h') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="quotaAutoPause5hPercent"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="100"
|
||||||
|
step="0.1"
|
||||||
|
class="input"
|
||||||
|
data-testid="ops-quota-auto-pause-5h"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="input-label">{{ t('admin.ops.settings.openaiQuotaAutoPauseDefault7d') }}</label>
|
||||||
|
<input
|
||||||
|
v-model.number="quotaAutoPause7dPercent"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="100"
|
||||||
|
step="0.1"
|
||||||
|
class="input"
|
||||||
|
data-testid="ops-quota-auto-pause-7d"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<p class="text-xs text-gray-500">{{ t('admin.ops.settings.openaiQuotaAutoPauseThresholdHint') }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Error Filtering -->
|
<!-- Error Filtering -->
|
||||||
<div class="space-y-3">
|
<div class="space-y-3">
|
||||||
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.errorFiltering') }}</h5>
|
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.errorFiltering') }}</h5>
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user