diff --git a/.gitignore b/.gitignore
index bf7ee064..cf251f07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
docs/claude-relay-service/
+.codex
# ===================
# Go 后端
@@ -121,7 +122,7 @@ scripts
.code-review-state
#openspec/
code-reviews/
-#AGENTS.md
+AGENTS.md
backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
diff --git a/README.md b/README.md
index 3e609d65..718730c6 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.
+
+
+Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. It offers the full Claude and Codex model series, 100% sourced directly from official providers — no dilution, no substitution, and open to verification. Billing is fully transparent, with token-level invoices that can be audited line by line.
+Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
+Register now via this link to receive $3 in trial credits. User top-ups are discounted by up to 40% (as low as 60% of the regular price), and referring a friend rewards both parties — referral bonuses reach up to $150.
+
+
## Ecosystem
diff --git a/README_CN.md b/README_CN.md
index add32a17..24600e0e 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
+感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
+同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
+现在通过 此链接 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。
+
+
## 生态项目
diff --git a/README_JA.md b/README_JA.md
index ccd595b9..1e89610c 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
+PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
+エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
+こちらのリンク から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。
+
+
## エコシステム
diff --git a/assets/partners/logos/pateway.png b/assets/partners/logos/pateway.png
new file mode 100644
index 00000000..7ca3489a
Binary files /dev/null and b/assets/partners/logos/pateway.png differ
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 1fcba8fa..025c3166 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.118
+0.1.121
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 6aaaa0c4..ff385516 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
- apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
+ apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
@@ -143,6 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
@@ -155,7 +156,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
windsurfTokenProvider := service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository)
windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache)
windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository)
- accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
+ accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -182,7 +183,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
@@ -191,7 +191,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 1b0a79ec..d52f5342 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -33,6 +33,7 @@ const (
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeWindsurfSession = "windsurf-session" // Windsurf Session 类型账号(邮箱密码登录获取的 session token + api_key)
+ AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 4dcfaa6b..c107c329 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -99,7 +99,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
- Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
+ Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -118,7 +118,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
- Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
+ Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -135,19 +135,29 @@ type UpdateAccountRequest struct {
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
- AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
- Name string `json:"name"`
- ProxyID *int64 `json:"proxy_id"`
- Concurrency *int `json:"concurrency"`
- Priority *int `json:"priority"`
- RateMultiplier *float64 `json:"rate_multiplier"`
- LoadFactor *int `json:"load_factor"`
- Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
- Schedulable *bool `json:"schedulable"`
- GroupIDs *[]int64 `json:"group_ids"`
- Credentials map[string]any `json:"credentials"`
- Extra map[string]any `json:"extra"`
- ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
+ AccountIDs []int64 `json:"account_ids"`
+ Filters *BulkUpdateAccountFilters `json:"filters"`
+ Name string `json:"name"`
+ ProxyID *int64 `json:"proxy_id"`
+ Concurrency *int `json:"concurrency"`
+ Priority *int `json:"priority"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
+ LoadFactor *int `json:"load_factor"`
+ Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
+ Schedulable *bool `json:"schedulable"`
+ GroupIDs *[]int64 `json:"group_ids"`
+ Credentials map[string]any `json:"credentials"`
+ Extra map[string]any `json:"extra"`
+ ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
+}
+
+type BulkUpdateAccountFilters struct {
+ Platform string `json:"platform"`
+ Type string `json:"type"`
+ Status string `json:"status"`
+ Group string `json:"group"`
+ Search string `json:"search"`
+ PrivacyMode string `json:"privacy_mode"`
}
// CheckMixedChannelRequest represents check mixed channel risk request
@@ -1370,6 +1380,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
+ if len(req.AccountIDs) == 0 && req.Filters == nil {
+ response.BadRequest(c, "account_ids or filters is required")
+ return
+ }
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
@@ -1395,6 +1409,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
+ Filters: toServiceBulkUpdateAccountFilters(req.Filters),
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
@@ -1430,6 +1445,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.Success(c, result)
}
+func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
+ if filters == nil {
+ return nil
+ }
+ return &service.BulkUpdateAccountFilters{
+ Platform: filters.Platform,
+ Type: filters.Type,
+ Status: filters.Status,
+ Group: filters.Group,
+ Search: filters.Search,
+ PrivacyMode: filters.PrivacyMode,
+ }
+}
+
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
index 24ec5bcf..929dc240 100644
--- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go
+++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}
+
+func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountMixedChannelRouter(adminSvc)
+
+ body, _ := json.Marshal(map[string]any{
+ "filters": map[string]any{
+ "platform": "openai",
+ "type": "oauth",
+ "status": "active",
+ "group": "12",
+ "privacy_mode": "blocked",
+ "search": "bulk-target",
+ },
+ "schedulable": true,
+ })
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, float64(0), resp["code"])
+}
diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go
index 3833d32e..6df49154 100644
--- a/backend/internal/handler/admin/admin_helpers_test.go
+++ b/backend/internal/handler/admin/admin_helpers_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
+
+// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
+// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
+// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
+func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
+ t.Run("nil input returns nil", func(t *testing.T) {
+ require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
+ })
+
+ t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: "",
+ Action: "filter",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.NotNil(t, out)
+ require.Len(t, out.Rules, 1)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+ require.Equal(t, "all", out.Rules[0].ServiceTier)
+ })
+
+ t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: " ",
+ Action: "pass",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+ })
+
+ t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: "PRIORITY",
+ Action: "filter",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+ })
+
+ t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{
+ {ServiceTier: "priority", Action: "filter", Scope: "all"},
+ {ServiceTier: "flex", Action: "block", Scope: "oauth"},
+ {ServiceTier: "all", Action: "pass", Scope: "apikey"},
+ },
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Len(t, out.Rules, 3)
+ require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+ require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
+ })
+}
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 2fe29fa3..b187b47f 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
+func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
+ for i := range s.apiKeys {
+ if s.apiKeys[i].ID == keyID {
+ s.apiKeys[i].Usage5h = 0
+ s.apiKeys[i].Usage1d = 0
+ s.apiKeys[i].Usage7d = 0
+ s.apiKeys[i].Window5hStart = nil
+ s.apiKeys[i].Window1dStart = nil
+ s.apiKeys[i].Window7dStart = nil
+ k := s.apiKeys[i]
+ return &k, nil
+ }
+ }
+ return nil, service.ErrAPIKeyNotFound
+}
+
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}
diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go
index 8dd245a4..5e405bdd 100644
--- a/backend/internal/handler/admin/apikey_handler.go
+++ b/backend/internal/handler/admin/apikey_handler.go
@@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
}
}
-// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
type AdminUpdateAPIKeyGroupRequest struct {
- GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+ GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+ ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
}
-// UpdateGroup handles updating an API key's group binding
+// UpdateGroup handles updating an API key's admin-managed fields.
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
return
}
+ var resetKey *service.APIKey
+ if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
+ resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ if resetKey != nil && req.GroupID == nil {
+ result.APIKey = resetKey
+ }
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go
index bf128b18..6ac6d52f 100644
--- a/backend/internal/handler/admin/apikey_handler_test.go
+++ b/backend/internal/handler/admin/apikey_handler_test.go
@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
require.Nil(t, resp.Data.APIKey.GroupID)
}
+func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
+ svc := newStubAdminService()
+ now := time.Now()
+ svc.apiKeys[0].Usage5h = 1.2
+ svc.apiKeys[0].Usage1d = 3.4
+ svc.apiKeys[0].Usage7d = 5.6
+ svc.apiKeys[0].Window5hStart = &now
+ svc.apiKeys[0].Window1dStart = &now
+ svc.apiKeys[0].Window7dStart = &now
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Data struct {
+ APIKey struct {
+ Usage5h float64 `json:"usage_5h"`
+ Usage1d float64 `json:"usage_1d"`
+ Usage7d float64 `json:"usage_7d"`
+ Window5hStart *time.Time `json:"window_5h_start"`
+ Window1dStart *time.Time `json:"window_1d_start"`
+ Window7dStart *time.Time `json:"window_7d_start"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Zero(t, resp.Data.APIKey.Usage5h)
+ require.Zero(t, resp.Data.APIKey.Usage1d)
+ require.Zero(t, resp.Data.APIKey.Usage7d)
+ require.Nil(t, resp.Data.APIKey.Window5hStart)
+ require.Nil(t, resp.Data.APIKey.Window1dStart)
+ require.Nil(t, resp.Data.APIKey.Window7dStart)
+}
+
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 40bf1c69..59f4fe85 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
AffiliateRebateRate: settings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
@@ -206,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
+ EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
@@ -245,9 +249,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateEnabled: settings.AffiliateEnabled,
}
+
+ // OpenAI fast policy (stored under a dedicated setting key)
+ if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+ slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+ } else if fastPolicy != nil {
+ payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+ }
+
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
+// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
+func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
+ if s == nil {
+ return nil
+ }
+ rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
+ for i, r := range s.Rules {
+ rules[i] = dto.OpenAIFastPolicyRule(r)
+ }
+ return &dto.OpenAIFastPolicySettings{Rules: rules}
+}
+
+// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
+//
+// ServiceTier is normalized before the DTO enters the service layer: it is
+// trimmed and lowercased, and an empty result becomes
+// service.OpenAIFastTierAny ("all") so "" and "all" are not both persisted to
+// mean "match any tier". service.SetOpenAIFastPolicySettings validates the value.
+func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
+ if s == nil {
+ return nil
+ }
+ rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
+ for i, r := range s.Rules {
+ rules[i] = service.OpenAIFastPolicyRule(r)
+ tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
+ if tier == "" {
+ tier = service.OpenAIFastTierAny
+ }
+ rules[i].ServiceTier = tier
+ }
+ return &service.OpenAIFastPolicySettings{Rules: rules}
+}
+
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
@@ -342,6 +388,9 @@ type UpdateSettingsRequest struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
@@ -393,9 +442,10 @@ type UpdateSettingsRequest struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
- EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
- EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
- EnableCCHSigning *bool `json:"enable_cch_signing"`
+ EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
+ EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning *bool `json:"enable_cch_signing"`
+ EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
@@ -446,6 +496,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
+
+ // OpenAI fast/flex policy (optional, only updated when provided)
+ OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
// UpdateSettings 更新系统设置
@@ -485,6 +538,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if affiliateRebateRate > service.AffiliateRebateRateMax {
affiliateRebateRate = service.AffiliateRebateRateMax
}
+ affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
+ if req.AffiliateRebateFreezeHours != nil {
+ affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
+ }
+ if affiliateRebateFreezeHours < 0 {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
+ }
+ if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
+ }
+ affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
+ if req.AffiliateRebateDurationDays != nil {
+ affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
+ }
+ if affiliateRebateDurationDays < 0 {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
+ }
+ if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
+ }
+ affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
+ if req.AffiliateRebatePerInviteeCap != nil {
+ affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
+ }
+ if affiliateRebatePerInviteeCap < 0 {
+ affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
+ }
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -1137,6 +1217,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
+ AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: affiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
@@ -1192,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ EnableAnthropicCacheTTL1hInjection: func() bool {
+ if req.EnableAnthropicCacheTTL1hInjection != nil {
+ return *req.EnableAnthropicCacheTTL1hInjection
+ }
+ return previousSettings.EnableAnthropicCacheTTL1hInjection
+ }(),
PaymentVisibleMethodAlipaySource: func() string {
if req.PaymentVisibleMethodAlipaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
@@ -1314,6 +1403,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
+ // Update OpenAI fast policy (stored under dedicated key, only when provided).
+ if req.OpenAIFastPolicySettings != nil {
+ if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ }
+
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
@@ -1458,6 +1555,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
@@ -1478,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
@@ -1516,6 +1617,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
+ if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+ slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+ } else if fastPolicy != nil {
+ payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+ }
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@@ -1768,6 +1874,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AffiliateRebateRate != after.AffiliateRebateRate {
changed = append(changed, "affiliate_rebate_rate")
}
+ if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
+ changed = append(changed, "affiliate_rebate_freeze_hours")
+ }
+ if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
+ changed = append(changed, "affiliate_rebate_duration_days")
+ }
+ if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
+ changed = append(changed, "affiliate_rebate_per_invitee_cap")
+ }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
@@ -1843,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
+ changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
+ }
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
changed = append(changed, "payment_visible_method_alipay_source")
}
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
index 9a33a93a..085fd2ca 100644
--- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
}
+// GetValue returns the stubbed value stored under key, or "" with a nil error
+// when the key has no entry.
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
- panic("unexpected GetValue call")
+	// Reading from a nil map is safe and reports ok == false, so no nil guard is needed.
+	if value, ok := s.values[key]; ok {
+		return value, nil
+	}
+	return "", nil
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 2ef05963..7df4abfd 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
type completeLinuxDoOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` // mandatory per the binding tag
+ AffCode string `json:"aff_code,omitempty"` // optional; passed as the trailing argument to LoginOrRegisterOAuthWithTokenPair
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` // pointer, so an omitted field stays nil
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` // pointer, so an omitted field stays nil
}
@@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 604ad903..490afd0f 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct {
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
+ AffCode string `json:"aff_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
@@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
+ strings.TrimSpace(req.AffCode),
); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 0ac8871b..4264002d 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` // mandatory per the binding tag
+ AffCode string `json:"aff_code,omitempty"` // optional; passed as the trailing argument to LoginOrRegisterOAuthWithTokenPair
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` // pointer, so an omitted field stays nil
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` // pointer, so an omitted field stays nil
}
@@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index efee4cc0..34e70ed0 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"` // mandatory per the binding tag
+ AffCode string `json:"aff_code,omitempty"` // optional; passed as the trailing argument to LoginOrRegisterOAuthWithTokenPair
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` // pointer, so an omitted field stays nil
AdoptAvatar *bool `json:"adopt_avatar,omitempty"` // pointer, so an omitted field stays nil
}
@@ -547,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 051fab18..492be170 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -106,11 +106,14 @@ type SystemSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
- DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
- DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -139,9 +142,10 @@ type SystemSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
- EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
- EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
- EnableCCHSigning bool `json:"enable_cch_signing"`
+ EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
+ EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning bool `json:"enable_cch_signing"`
+ EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
@@ -195,6 +199,9 @@ type SystemSettings struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
+
+ // OpenAI fast/flex policy
+ OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
type DefaultSubscriptionSetting struct {
@@ -291,6 +298,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
+// OpenAIFastPolicyRule is the DTO for a single OpenAI fast/flex policy rule.
+type OpenAIFastPolicyRule struct {
+ ServiceTier string `json:"service_tier"`
+ Action string `json:"action"`
+ Scope string `json:"scope"`
+ ErrorMessage string `json:"error_message,omitempty"`
+ ModelWhitelist []string `json:"model_whitelist,omitempty"`
+ FallbackAction string `json:"fallback_action,omitempty"`
+ FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
+}
+
+// OpenAIFastPolicySettings is the DTO for the OpenAI fast policy configuration; it wraps the rule list.
+type OpenAIFastPolicySettings struct {
+ Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 7bb9c46d..d12c2941 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -288,6 +288,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
+ // [DEBUG-STICKY] 打印会话 hash 生成结果
+ reqLog.Info("sticky.session_hash_generated",
+ zap.String("session_hash", sessionHash),
+ zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
+ )
+
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := ""
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
@@ -304,6 +310,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
+ // [DEBUG-STICKY] 打印粘性会话查询结果
+ reqLog.Info("sticky.cache_lookup",
+ zap.String("session_key", sessionKey),
+ zap.Int64("bound_account_id", sessionBoundAccountID),
+ )
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
@@ -312,6 +323,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
+ } else {
+ reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
@@ -591,6 +604,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
+ reqLog.Info("sticky.selecting_account",
+ zap.String("session_key", sessionKey),
+ zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
+ zap.Bool("has_bound_session", hasBoundSession),
+ zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
+ )
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
@@ -624,6 +643,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
+ // [DEBUG-STICKY] 打印账号选择结果
+ reqLog.Info("sticky.account_selected",
+ zap.Int64("selected_account_id", account.ID),
+ zap.String("account_name", account.Name),
+ zap.Bool("slot_acquired", selection.Acquired),
+ zap.Bool("has_wait_plan", selection.WaitPlan != nil),
+ zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
+ zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
+ )
+
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
@@ -690,6 +719,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// Slot acquired: no longer waiting in queue.
releaseWait()
+ reqLog.Info("sticky.bind_after_wait",
+ zap.String("session_key", sessionKey),
+ zap.Int64("account_id", account.ID),
+ )
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
@@ -924,6 +957,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
+ // 绑定粘性会话(成功转发后绑定/刷新)
+ // - 无现有绑定(首次请求):创建绑定
+ // - 选中账号与粘性账号一致:刷新 TTL
+ // - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
+ // 下次请求粘性账号恢复后仍可命中
+ if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
+ if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
+ reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ }
+
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index 71030140..57554cf9 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
+func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
+ return nil
+}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
index 403b41ef..4d0078a7 100644
--- a/backend/internal/handler/openai_images.go
+++ b/backend/internal/handler/openai_images.go
@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
- sessionHash := ""
- if parsed.Multipart {
- sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
- } else {
- sessionHash = h.gatewayService.GenerateSessionHash(c, body)
- }
+ sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index 37bd38b2..e7d8aab9 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
+ maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
@@ -42,17 +43,55 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
- if config[k] == "" {
+ // A value that is empty after trimming whitespace counts as missing.
+ if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
+ // Work on a copy so that normalizing apiBase does not modify the caller's map.
+ cfg := make(map[string]string, len(config))
+ for k, v := range config {
+ cfg[k] = v
+ }
+ cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{
instanceID: instanceID,
- config: config,
+ config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
+// normalizeEasyPayAPIBase reduces a configured EasyPay address to a bare base URL.
+//
+// Surrounding whitespace and trailing slashes are removed, and a trailing
+// endpoint script (as recognized by trimEasyPayEndpointPath) is cut off. When
+// the value parses as a URL with both scheme and host, its query string and
+// fragment are discarded as well; otherwise only the string-level trimming is
+// applied. A blank input yields "".
+func normalizeEasyPayAPIBase(apiBase string) string {
+	trimmed := strings.TrimSpace(apiBase)
+	if trimmed == "" {
+		return ""
+	}
+	u, err := url.Parse(trimmed)
+	absolute := err == nil && u.Scheme != "" && u.Host != ""
+	if !absolute {
+		// Not a scheme://host URL: fall back to plain suffix trimming.
+		return strings.TrimRight(trimEasyPayEndpointPath(trimmed), "/")
+	}
+	// Clear RawPath together with Path so String() re-encodes the trimmed path.
+	u.RawQuery, u.Fragment, u.RawPath = "", "", ""
+	u.Path = trimEasyPayEndpointPath(u.Path)
+	return strings.TrimRight(u.String(), "/")
+}
+
+// trimEasyPayEndpointPath strips surrounding whitespace and trailing slashes
+// from path and, when the result ends (case-insensitively) in /submit.php,
+// /mapi.php or /api.php, removes that script name plus any slashes before it.
+func trimEasyPayEndpointPath(path string) string {
+	cleaned := strings.TrimSpace(path)
+	cleaned = strings.TrimRight(cleaned, "/")
+	folded := strings.ToLower(cleaned)
+	endpoints := [...]string{"/submit.php", "/mapi.php", "/api.php"}
+	for i := range endpoints {
+		suffix := endpoints[i]
+		if !strings.HasSuffix(folded, suffix) {
+			continue
+		}
+		// Cut from the original-case string so the remaining prefix keeps its casing.
+		cut := len(cleaned) - len(suffix)
+		return strings.TrimRight(cleaned[:cut], "/")
+	}
+	return cleaned
+}
+
+// apiBase returns the normalized value of the "apiBase" config entry, or ""
+// for a nil receiver. The value is normalized on every call — presumably to
+// cover instances whose config did not pass through NewEasyPay (confirm).
+func (e *EasyPay) apiBase() string {
+	if e != nil {
+		return normalizeEasyPayAPIBase(e.config["apiBase"])
+	}
+	return ""
+}
+
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
@@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params {
q.Set(k, v)
}
- base := strings.TrimRight(e.config["apiBase"], "/")
- payURL := base + "/submit.php?" + q.Encode()
+ payURL := e.apiBase() + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
@@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
- body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
+ body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
@@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
- body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
+ body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
@@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
}
+// Refund posts a refund request to <apiBase>/api.php?act=refund, trying the
+// candidates produced by refundAttempts in order. A transport error aborts
+// immediately. A provider-level error is returned as is, unless it matches
+// isEasyPayRefundOrderNotFound and another candidate remains, in which case the
+// next candidate is tried. On success RefundID is the identifier that was sent.
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
- params := map[string]string{
- "pid": e.config["pid"], "key": e.config["pkey"],
- "trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
+ attempts := e.refundAttempts(req)
+ if len(attempts) == 0 {
+ return nil, fmt.Errorf("easypay refund missing order identifier")
}
- body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
- if err != nil {
- return nil, fmt.Errorf("easypay refund: %w", err)
+ var firstErr error
+ for i, attempt := range attempts {
+ body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
+ if err != nil {
+ return nil, fmt.Errorf("easypay refund request: %w", err)
+ }
+ if err := parseEasyPayRefundResponse(status, body); err != nil {
+ if firstErr == nil {
+ firstErr = err
+ }
+ // Only move on to the next identifier when the provider reports the order as unknown.
+ if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
+ continue
+ }
+ return nil, err
+ }
+ return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
+ // NOTE(review): the last loop iteration always returns, so with a non-empty
+ // attempts slice this line looks unreachable and firstErr is never surfaced.
+ return nil, firstErr
+}
+
+// easyPayRefundAttempt is one candidate refund request built by refundAttempts.
+type easyPayRefundAttempt struct {
+ params map[string]string // form fields posted for this attempt
+ refundID string // identifier reported as RefundID when this attempt succeeds
+}
+
+// refundAttempts builds the ordered list of refund requests to try for req.
+//
+// Each attempt carries pid, key and money plus exactly one order identifier:
+// first out_trade_no (from req.OrderID), then trade_no (from req.TradeNo).
+// Identifiers that are blank after trimming are skipped, so the result is nil
+// when neither is set.
+func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
+	shared := map[string]string{
+		"pid":   e.config["pid"],
+		"key":   e.config["pkey"],
+		"money": req.Amount,
+	}
+	var candidates []easyPayRefundAttempt
+	// add appends one attempt that identifies the order through field.
+	add := func(field, rawID string) {
+		id := strings.TrimSpace(rawID)
+		if id == "" {
+			return
+		}
+		fields := cloneStringMap(shared)
+		fields[field] = id
+		candidates = append(candidates, easyPayRefundAttempt{params: fields, refundID: id})
+	}
+	add("out_trade_no", req.OrderID)
+	add("trade_no", req.TradeNo)
+	return candidates
+}
+
+// cloneStringMap returns a new map holding the same entries as in. The result
+// is never nil, even when in is nil.
+func cloneStringMap(in map[string]string) map[string]string {
+	dup := make(map[string]string, len(in))
+	for key := range in {
+		dup[key] = in[key]
+	}
+	return dup
+}
+
+// isEasyPayRefundOrderNotFound reports whether err's message says the order
+// could not be found. It looks for the Chinese phrases "订单编号不存在" and
+// "订单不存在" verbatim, and for "order not found" or "not exist"
+// case-insensitively. A nil err yields false.
+func isEasyPayRefundOrderNotFound(err error) bool {
+	if err == nil {
+		return false
+	}
+	text := err.Error()
+	for _, marker := range []string{"订单编号不存在", "订单不存在"} {
+		if strings.Contains(text, marker) {
+			return true
+		}
+	}
+	folded := strings.ToLower(text)
+	for _, marker := range []string{"order not found", "not exist"} {
+		if strings.Contains(folded, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+func parseEasyPayRefundResponse(status int, body []byte) error {
+ summary := summarizeEasyPayResponse(body)
+ if status < http.StatusOK || status >= http.StatusMultipleChoices {
+ return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
+ }
+
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
+ }
+
+ lower := strings.ToLower(trimmed)
+ if strings.HasPrefix(lower, ""
+ }
+ if len(summary) > maxEasypayErrorSummary {
+ return summary[:maxEasypayErrorSummary] + "..."
+ }
+ return summary
}
func (e *EasyPay) resolveCID(paymentType string) string {
@@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
}
+// post is postRaw without the HTTP status code: it returns only the response body and error.
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
+ body, _, err := e.postRaw(ctx, endpoint, params)
+ return body, err
+}
+
+// postRaw sends params as a form-encoded POST to endpoint and returns the
+// response body (at most maxEasypayResponseSize bytes are read) together with
+// the HTTP status code. The status is 0 when the request could not be built
+// or sent; the status code itself is not checked here.
+func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
- return nil, err
+ return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- resp, err := e.httpClient.Do(req)
+ // Fall back to a client with the default timeout when none is set on the provider.
+ client := e.httpClient
+ if client == nil {
+ client = &http.Client{Timeout: easypayHTTPTimeout}
+ }
+ resp, err := client.Do(req)
if err != nil {
- return nil, err
+ return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
- return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
+ if err != nil {
+ return nil, resp.StatusCode, err
+ }
+ return body, resp.StatusCode, nil
}
func easyPaySign(params map[string]string, pkey string) string {
diff --git a/backend/internal/payment/provider/easypay_refund_test.go b/backend/internal/payment/provider/easypay_refund_test.go
new file mode 100644
index 00000000..9e0e4942
--- /dev/null
+++ b/backend/internal/payment/provider/easypay_refund_test.go
@@ -0,0 +1,196 @@
+package provider
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+// TestNormalizeEasyPayAPIBase checks that a trailing slash, a trailing
+// submit.php/mapi.php/api.php script and a query string are all stripped
+// from the configured base URL.
+func TestNormalizeEasyPayAPIBase(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input string
+ want string
+ }{
+ {input: "https://zpayz.cn", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
+ }
+
+ // NOTE(review): tt is captured by a parallel subtest closure; this relies on
+ // per-iteration loop variables (Go 1.22+) — confirm the module's Go version.
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ t.Parallel()
+ if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
+ t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+// TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly configures apiBase
+// with a trailing /mapi.php and expects Refund to hit /api.php?act=refund with
+// pid, key, money and out_trade_no in the form, and no trade_no, when the
+// first attempt succeeds.
+func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
+ t.Parallel()
+
+ var gotPath string
+ var gotQuery url.Values
+ var gotForm url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotQuery = r.URL.Query()
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForm = r.PostForm
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess {
+ t.Fatalf("Refund response = %+v, want success", resp)
+ }
+ if gotPath != "/api.php" {
+ t.Fatalf("refund path = %q, want /api.php", gotPath)
+ }
+ if gotQuery.Get("act") != "refund" {
+ t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
+ }
+ for key, want := range map[string]string{
+ "pid": "pid-1",
+ "key": "pkey-1",
+ "out_trade_no": "out-456",
+ "money": "1.50",
+ } {
+ if got := gotForm.Get(key); got != want {
+ t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
+ }
+ }
+ if got := gotForm.Get("trade_no"); got != "" {
+ t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
+ }
+}
+
+// TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound has the stub
+// server answer the first request with an "order number does not exist" error
+// and the second with success. It expects exactly two requests: the first
+// carrying only out_trade_no, the second only trade_no, with the trade number
+// reported as RefundID.
+func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
+ t.Parallel()
+
+ var gotForms []url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api.php" {
+ t.Errorf("refund path = %q, want /api.php", r.URL.Path)
+ }
+ if r.URL.Query().Get("act") != "refund" {
+ t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
+ }
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForms = append(gotForms, r.PostForm)
+ w.Header().Set("Content-Type", "application/json")
+ if len(gotForms) == 1 {
+ _, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
+ return
+ }
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
+ t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
+ }
+ if len(gotForms) != 2 {
+ t.Fatalf("refund attempts = %d, want 2", len(gotForms))
+ }
+ if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
+ t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[0].Get("trade_no"); got != "" {
+ t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[1].Get("trade_no"); got != "trade-123" {
+ t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
+ }
+ if got := gotForms[1].Get("out_trade_no"); got != "" {
+ t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
+ }
+}
+
+// TestEasyPayRefundResponseErrors serves a fixed status code and body for each
+// case and expects Refund to fail with an error containing the case's want
+// substring.
+func TestEasyPayRefundResponseErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ statusCode int
+ body string
+ want string
+ }{
+ {name: "html response", statusCode: http.StatusOK, body: "bad config", want: "non-JSON response (HTTP 200): bad config"},
+ {name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
+ {name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
+ {name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): "},
+ }
+
+ // NOTE(review): tt is captured by a parallel subtest closure; this relies on
+ // per-iteration loop variables (Go 1.22+) — confirm the module's Go version.
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(tt.statusCode)
+ _, _ = w.Write([]byte(tt.body))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL)
+ _, err := provider.Refund(context.Background(), payment.RefundRequest{
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err == nil {
+ t.Fatal("Refund returned nil error")
+ }
+ if !strings.Contains(err.Error(), tt.want) {
+ t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
+ }
+ })
+ }
+}
+
+// newTestEasyPay builds an EasyPay provider through NewEasyPay with fixed test
+// credentials and the given apiBase, failing the test if construction errors.
+func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
+ t.Helper()
+
+ provider, err := NewEasyPay("test-instance", map[string]string{
+ "pid": "pid-1",
+ "pkey": "pkey-1",
+ "apiBase": apiBase,
+ "notifyUrl": "https://example.com/notify",
+ "returnUrl": "https://example.com/return",
+ })
+ if err != nil {
+ t.Fatalf("NewEasyPay: %v", err)
+ }
+ return provider
+}
diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go
index 095305c2..e8b25c2b 100644
--- a/backend/internal/pkg/apicompat/anthropic_responses_test.go
+++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go
@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
+// TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics expects
+// cached tokens to be reported separately from input tokens: 54006 input
+// tokens with 50688 cached become InputTokens 3318 (54006 - 50688) and
+// CacheReadInputTokens 50688.
+func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "message",
+ Content: []ResponsesContentPart{
+ {Type: "output_text", Text: "Cached response"},
+ },
+ },
+ },
+ Usage: &ResponsesUsage{
+ InputTokens: 54006,
+ OutputTokens: 123,
+ TotalTokens: 54129,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 50688,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 3318, anth.Usage.InputTokens)
+ assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 123, anth.Usage.OutputTokens)
+}
+
+// TestResponsesToAnthropic_CachedTokensClampInputTokens expects InputTokens to
+// be 0 rather than negative when the cached count (150) exceeds the reported
+// input count (100); CacheReadInputTokens keeps the full cached count.
+func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached_clamp",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Usage: &ResponsesUsage{
+ InputTokens: 100,
+ OutputTokens: 5,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 150,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 0, anth.Usage.InputTokens)
+ assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 5, anth.Usage.OutputTokens)
+}
+
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
@@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
+ assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
+}
+
+// A Read call with an empty "pages" argument is converted with that key dropped.
+func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
+	call := ResponsesOutput{
+		Type:      "function_call",
+		CallID:    "call_read",
+		Name:      "Read",
+		Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+	}
+	upstream := &ResponsesResponse{
+		ID:     "resp_read",
+		Model:  "gpt-5.5",
+		Status: "completed",
+		Output: []ResponsesOutput{call},
+	}
+
+	blocks := ResponsesToAnthropic(upstream, "claude-opus-4-6").Content
+	require.Len(t, blocks, 1)
+	assert.Equal(t, "tool_use", blocks[0].Type)
+	assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(blocks[0].Input))
+}
+
+// Empty-string arguments of tools other than Read pass through unchanged.
+func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
+	call := ResponsesOutput{
+		Type:      "function_call",
+		CallID:    "call_other",
+		Name:      "Search",
+		Arguments: `{"query":""}`,
+	}
+	upstream := &ResponsesResponse{
+		ID:     "resp_other",
+		Model:  "gpt-5.5",
+		Status: "completed",
+		Output: []ResponsesOutput{call},
+	}
+
+	blocks := ResponsesToAnthropic(upstream, "claude-opus-4-6").Content
+	require.Len(t, blocks, 1)
+	assert.JSONEq(t, `{"query":""}`, string(blocks[0].Input))
 }
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
@@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
+// The usage of a streamed response.completed event is reported on message_delta
+// with cached tokens split out of InputTokens, followed by message_stop.
+func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
+	state := NewResponsesEventToAnthropicState()
+	created := &ResponsesStreamEvent{
+		Type:     "response.created",
+		Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
+	}
+	ResponsesEventToAnthropicEvents(created, state)
+
+	usage := &ResponsesUsage{
+		InputTokens:        54006,
+		OutputTokens:       123,
+		TotalTokens:        54129,
+		InputTokensDetails: &ResponsesInputTokensDetails{CachedTokens: 50688},
+	}
+	completed := &ResponsesStreamEvent{
+		Type:     "response.completed",
+		Response: &ResponsesResponse{Status: "completed", Usage: usage},
+	}
+	events := ResponsesEventToAnthropicEvents(completed, state)
+
+	require.Len(t, events, 2)
+	assert.Equal(t, "message_delta", events[0].Type)
+	assert.Equal(t, 3318, events[0].Usage.InputTokens)
+	assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
+	assert.Equal(t, 123, events[0].Usage.OutputTokens)
+	assert.Equal(t, "message_stop", events[1].Type)
+}
+
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
+// Streaming a Read tool call: the argument delta produces no events, and the
+// done event yields one input_json_delta without the empty "pages" key,
+// followed by content_block_stop.
+func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
+	const upstreamArgs = `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`
+	state := NewResponsesEventToAnthropicState()
+	send := func(evt *ResponsesStreamEvent) []AnthropicStreamEvent {
+		return ResponsesEventToAnthropicEvents(evt, state)
+	}
+
+	send(&ResponsesStreamEvent{
+		Type:     "response.created",
+		Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
+	})
+
+	started := send(&ResponsesStreamEvent{
+		Type:        "response.output_item.added",
+		OutputIndex: 0,
+		Item:        &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
+	})
+	require.Len(t, started, 1)
+	assert.Equal(t, "content_block_start", started[0].Type)
+
+	// The delta must not be forwarded to the client.
+	buffered := send(&ResponsesStreamEvent{Type: "response.function_call_arguments.delta", OutputIndex: 0, Delta: upstreamArgs})
+	assert.Len(t, buffered, 0)
+
+	finished := send(&ResponsesStreamEvent{Type: "response.function_call_arguments.done", OutputIndex: 0, Arguments: upstreamArgs})
+	require.Len(t, finished, 2)
+	assert.Equal(t, "content_block_delta", finished[0].Type)
+	assert.Equal(t, "input_json_delta", finished[0].Delta.Type)
+	assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, finished[0].Delta.PartialJSON)
+	assert.Equal(t, "content_block_stop", finished[1].Type)
+}
+
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -835,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
- fn, ok := tc["function"].(map[string]any)
- require.True(t, ok)
- assert.Equal(t, "get_weather", fn["name"])
+ assert.Equal(t, "get_weather", tc["name"])
+ assert.NotContains(t, tc, "function")
+}
+
+// The flat Responses form {"type":"function","name":X} is converted to the
+// Anthropic form {"type":"tool","name":X}.
+func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
+	converted, err := ResponsesToAnthropicRequest(&ResponsesRequest{
+		Model:      "gpt-5.2",
+		Input:      json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+		ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
+	})
+	require.NoError(t, err)
+
+	choice := map[string]string{}
+	require.NoError(t, json.Unmarshal(converted.ToolChoice, &choice))
+	assert.Equal(t, "tool", choice["type"])
+	assert.Equal(t, "get_weather", choice["name"])
+}
+
+// The legacy nested form {"type":"function","function":{"name":X}} is still
+// accepted and converted to {"type":"tool","name":X}.
+func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
+	converted, err := ResponsesToAnthropicRequest(&ResponsesRequest{
+		Model:      "gpt-5.2",
+		Input:      json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+		ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
+	})
+	require.NoError(t, err)
+
+	choice := map[string]string{}
+	require.NoError(t, json.Unmarshal(converted.ToolChoice, &choice))
+	assert.Equal(t, "tool", choice["type"])
+	assert.Equal(t, "get_weather", choice["name"])
 }
// ---------------------------------------------------------------------------
diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go
index 485262e8..268f9f22 100644
--- a/backend/internal/pkg/apicompat/anthropic_to_responses.go
+++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go
@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
-// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct {
Type string `json:"type"`
@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
return json.Marshal("none")
case "tool":
return json.Marshal(map[string]any{
- "type": "function",
- "function": map[string]string{"name": tc.Name},
+ "type": "function",
+ "name": tc.Name,
})
default:
// Pass through unknown types as-is
diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
index c140449a..35d42999 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
+ assert.Equal(t, "get_weather", tc["name"])
+ assert.NotContains(t, tc, "function")
}
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
index c2725406..64ef5781 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
//
// "auto" → "auto"
// "none" → "none"
-// {"name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"name":"X"} → {"type":"function","name":"X"}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
return nil, err
}
return json.Marshal(map[string]any{
- "type": "function",
- "function": map[string]string{"name": obj.Name},
+ "type": "function",
+ "name": obj.Name,
})
}
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go
index 5409a0f4..489ed238 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go
@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
- Input: json.RawMessage(item.Arguments),
+ Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
- out.Usage = AnthropicUsage{
- InputTokens: resp.Usage.InputTokens,
- OutputTokens: resp.Usage.OutputTokens,
- }
- if resp.Usage.InputTokensDetails != nil {
- out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
- }
+ out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
return out
}
+// anthropicUsageFromResponsesUsage maps Responses token usage onto AnthropicUsage.
+// Cached tokens (InputTokensDetails.CachedTokens, 0 when details are absent) are
+// reported as CacheReadInputTokens and subtracted from InputTokens, which is
+// clamped so it never goes negative. A nil usage yields the zero AnthropicUsage.
+func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
+	var converted AnthropicUsage
+	if usage == nil {
+		return converted
+	}
+
+	converted.OutputTokens = usage.OutputTokens
+	if details := usage.InputTokensDetails; details != nil {
+		converted.CacheReadInputTokens = details.CachedTokens
+	}
+
+	// Only a positive remainder is stored; otherwise InputTokens stays 0.
+	if uncached := usage.InputTokens - converted.CacheReadInputTokens; uncached > 0 {
+		converted.InputTokens = uncached
+	}
+	return converted
+}
+
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
@@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
}
}
+// sanitizeAnthropicToolUseInput returns the tool arguments as raw JSON. For the
+// "Read" tool only, a "pages" key whose value is the empty string is removed;
+// anything else, including undecodable arguments, is passed through untouched.
+func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
+	passthrough := json.RawMessage(raw)
+	if raw == "" || name != "Read" {
+		return passthrough
+	}
+
+	fields := map[string]json.RawMessage{}
+	if json.Unmarshal(passthrough, &fields) != nil || string(fields["pages"]) != `""` {
+		return passthrough
+	}
+
+	// Re-encode without the empty value; fall back to the input if that fails.
+	delete(fields, "pages")
+	if cleaned, err := json.Marshal(fields); err == nil {
+		return cleaned
+	}
+	return passthrough
+}
+
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
@@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
+ CurrentToolName string
+ CurrentToolArgs string
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
@@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
- return resToAnthHandleBlockDone(state)
+ return resToAnthHandleFuncArgsDone(evt, state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
@@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
+ state.CurrentToolName = evt.Item.Name
+ state.CurrentToolArgs = ""
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
@@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
return nil
}
+ if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
+ state.CurrentToolArgs += evt.Delta
+ return nil
+ }
+
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
@@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
}}
}
+// resToAnthHandleFuncArgsDone handles response.function_call_arguments.done. Read tool_use
+// blocks emit their sanitized arguments (event value, else the buffered deltas) as one
+// input_json_delta and are then closed; all other blocks go to resToAnthHandleBlockDone.
+func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
+	readTool := state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read"
+	if !readTool {
+		return resToAnthHandleBlockDone(state)
+	}
+	args := evt.Arguments
+	if args == "" {
+		args = state.CurrentToolArgs
+	}
+	payload := sanitizeAnthropicToolUseInput(state.CurrentToolName, args)
+	if len(payload) == 0 {
+		return closeCurrentBlock(state)
+	}
+
+	// Capture the index now: closeCurrentBlock increments ContentBlockIndex.
+	blockIdx := state.ContentBlockIndex
+	delta := AnthropicStreamEvent{
+		Type:  "content_block_delta",
+		Index: &blockIdx,
+		Delta: &AnthropicDelta{Type: "input_json_delta", PartialJSON: string(payload)},
+	}
+	return append([]AnthropicStreamEvent{delta}, closeCurrentBlock(state)...)
+}
+
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
@@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
- state.InputTokens = evt.Response.Usage.InputTokens
- state.OutputTokens = evt.Response.Usage.OutputTokens
- if evt.Response.Usage.InputTokensDetails != nil {
- state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
- }
+ usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
+ state.InputTokens = usage.InputTokens
+ state.OutputTokens = usage.OutputTokens
+ state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":
@@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
+ state.CurrentToolName = ""
+ state.CurrentToolArgs = ""
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
index 49426b88..8fa652f2 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
-// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
+// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
+// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
// Try as object with type=function
var tc struct {
Type string `json:"type"`
+ Name string `json:"name"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
- if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
+ if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
+ name := strings.TrimSpace(tc.Name)
+ if name == "" {
+ name = strings.TrimSpace(tc.Function.Name)
+ }
+ if name == "" {
+ return raw, nil
+ }
return json.Marshal(map[string]string{
"type": "tool",
- "name": tc.Function.Name,
+ "name": name,
})
}
diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go
index 69e99dc5..cee12948 100644
--- a/backend/internal/pkg/httputil/body.go
+++ b/backend/internal/pkg/httputil/body.go
@@ -2,16 +2,28 @@ package httputil
import (
"bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "errors"
+ "fmt"
"io"
"net/http"
+ "strings"
+
+ "github.com/klauspost/compress/zstd"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
+ // maxDecompressedBodySize limits the decompressed request body to 64 MB
+ // to prevent decompression bomb attacks.
+ maxDecompressedBodySize = 64 << 20
)
-// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
+// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
+// on content length, transparently decoding any Content-Encoding the upstream
+// client used to compress the body (zstd, gzip, deflate).
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
- return buf.Bytes(), nil
+ raw := buf.Bytes()
+
+ enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
+ if enc == "" || enc == "identity" {
+ return raw, nil
+ }
+
+ decoded, err := decompressRequestBody(enc, raw)
+ if err != nil {
+ return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
+ }
+
+ req.Header.Del("Content-Encoding")
+ req.Header.Del("Content-Length")
+ req.ContentLength = int64(len(decoded))
+
+ return decoded, nil
+}
+
+// decompressRequestBody decodes raw per encoding (zstd, gzip/x-gzip, deflate). It reads one
+// byte past maxDecompressedBodySize so oversized output is rejected, not silently truncated.
+func decompressRequestBody(encoding string, raw []byte) (out []byte, err error) {
+	var src io.ReadCloser
+	switch encoding {
+	case "zstd":
+		var dec *zstd.Decoder
+		if dec, err = zstd.NewReader(bytes.NewReader(raw)); err == nil {
+			src = dec.IOReadCloser()
+		}
+	case "gzip", "x-gzip":
+		src, err = gzip.NewReader(bytes.NewReader(raw))
+	case "deflate":
+		src, err = zlib.NewReader(bytes.NewReader(raw))
+	default:
+		return nil, errors.New("unsupported Content-Encoding")
+	}
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = src.Close() }()
+	out, err = io.ReadAll(io.LimitReader(src, maxDecompressedBodySize+1))
+	if err == nil && len(out) > maxDecompressedBodySize {
+		return nil, errors.New("decompressed body exceeds size limit")
+	}
+	return out, err
 }
diff --git a/backend/internal/pkg/httputil/body_test.go b/backend/internal/pkg/httputil/body_test.go
new file mode 100644
index 00000000..ed8355d5
--- /dev/null
+++ b/backend/internal/pkg/httputil/body_test.go
@@ -0,0 +1,143 @@
+package httputil
+
+import (
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/klauspost/compress/zstd"
+)
+
+const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
+
+func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
+	t.Helper() // attribute failures to the calling test
+	req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
+	switch {
+	case err != nil:
+		t.Fatalf("NewRequest: %v", err)
+	case encoding != "":
+		req.Header.Set("Content-Encoding", encoding) // header is set only when an encoding is given
+	}
+	req.ContentLength = int64(len(body))
+	return req
+}
+
+// A request without a Content-Encoding header is returned byte-for-byte.
+func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
+	body, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, []byte(samplePayload), ""))
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case string(body) != samplePayload:
+		t.Fatalf("body mismatch: got %q", body)
+	}
+}
+
+// A zstd body is decoded, the Content-Encoding header is cleared, and
+// ContentLength is updated to the decoded length.
+func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
+	enc, _ := zstd.NewWriter(nil)
+	compressed := enc.EncodeAll([]byte(samplePayload), nil)
+	_ = enc.Close()
+
+	req := newRequestWithBody(t, compressed, "zstd")
+	got, err := ReadRequestBodyWithPrealloc(req)
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case string(got) != samplePayload:
+		t.Fatalf("body mismatch: got %q", got)
+	case req.Header.Get("Content-Encoding") != "":
+		t.Fatalf("Content-Encoding should be cleared after decoding")
+	case req.ContentLength != int64(len(samplePayload)):
+		t.Fatalf("ContentLength not updated: %d", req.ContentLength)
+	}
+}
+
+// A gzip-encoded body is decoded back to the original payload.
+func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
+	var compressed bytes.Buffer
+	zipper := gzip.NewWriter(&compressed)
+	if _, err := zipper.Write([]byte(samplePayload)); err != nil {
+		t.Fatalf("gzip write: %v", err)
+	}
+	if err := zipper.Close(); err != nil {
+		t.Fatalf("gzip close: %v", err)
+	}
+
+	body, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, compressed.Bytes(), "gzip"))
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case string(body) != samplePayload:
+		t.Fatalf("body mismatch: got %q", body)
+	}
+}
+
+// A "deflate" body (zlib-wrapped, as written by zlib.NewWriter) is decoded.
+func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
+	var compressed bytes.Buffer
+	deflater := zlib.NewWriter(&compressed)
+	if _, err := deflater.Write([]byte(samplePayload)); err != nil {
+		t.Fatalf("zlib write: %v", err)
+	}
+	if err := deflater.Close(); err != nil {
+		t.Fatalf("zlib close: %v", err)
+	}
+
+	body, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, compressed.Bytes(), "deflate"))
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case string(body) != samplePayload:
+		t.Fatalf("body mismatch: got %q", body)
+	}
+}
+
+// An encoding with no decoder ("br") fails with an error naming that encoding.
+func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
+	_, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, []byte(samplePayload), "br"))
+	switch {
+	case err == nil:
+		t.Fatal("expected error for unsupported encoding, got nil")
+	case !strings.Contains(err.Error(), "br"):
+		t.Fatalf("error should mention encoding, got %v", err)
+	}
+}
+
+// Bytes that are labelled zstd but are not a valid zstd stream must produce
+// an error rather than a body.
+func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
+	if _, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, []byte("not actually zstd"), "zstd")); err == nil {
+		t.Fatal("expected error for corrupt zstd body, got nil")
+	}
+}
+
+func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
+	req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil) // request built with a nil body
+	if err != nil {
+		t.Fatalf("NewRequest: %v", err)
+	}
+	body, err := ReadRequestBodyWithPrealloc(req)
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case body != nil:
+		t.Fatalf("expected nil body, got %q", body)
+	}
+}
+
+// An explicit "Content-Encoding: identity" body is returned without decoding.
+func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
+	body, err := ReadRequestBodyWithPrealloc(newRequestWithBody(t, []byte(samplePayload), "identity"))
+	switch {
+	case err != nil:
+		t.Fatalf("unexpected error: %v", err)
+	case string(body) != samplePayload:
+		t.Fatalf("body mismatch: got %q", body)
+	}
+}
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index b249bb61..d1cea9eb 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
return true, nil
}
+func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+	return nil // stub: always reports success and records nothing
+}
+
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go
index e3dd56b8..ef89e5b6 100644
--- a/backend/internal/repository/affiliate_repo.go
+++ b/backend/internal/repository/affiliate_repo.go
@@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID
return bound, nil
}
-func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
+func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
if amount <= 0 {
return false, nil
}
var applied bool
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
- res, err := txClient.ExecContext(txCtx,
- "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
- amount, inviterID,
- )
+ // freezeHours > 0: add to frozen quota; == 0: add to available quota directly
+ var updateSQL string
+ if freezeHours > 0 {
+ updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
+ } else {
+ updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
+ }
+ res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
if err != nil {
return err
}
@@ -106,10 +110,19 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite
return nil
}
- if _, err = txClient.ExecContext(txCtx, `
+ if freezeHours > 0 {
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
+VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
+ inviterID, amount, inviteeUserID, freezeHours); err != nil {
+ return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ }
+ } else {
+ if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
- return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ }
}
applied = true
@@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID);
return applied, nil
}
+// GetAccruedRebateFromInvitee sums the 'accrue' ledger amounts inviterID received from inviteeUserID.
+func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
+	rows, err := clientFromContext(ctx, r.client).QueryContext(ctx,
+		`SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`, inviterID, inviteeUserID)
+	if err != nil {
+		return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+	if !rows.Next() {
+		return 0, rows.Err() // surface iteration failures instead of reporting 0
+	}
+	var total float64
+	if err := rows.Scan(&total); err != nil {
+		return 0, err
+	}
+	return total, rows.Close()
+}
+
+// ThawFrozenQuota runs thawFrozenQuotaTx for userID inside r.withTx and returns the thawed amount.
+func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
+	var amount float64
+	txErr := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) (err error) {
+		amount, err = thawFrozenQuotaTx(txCtx, txClient, userID)
+		return err
+	})
+	return amount, txErr
+}
+
+// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
+// It clears frozen_until on userID's matured ledger rows and returns the amount moved (0 if none).
+func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
+	rows, err := txClient.QueryContext(txCtx, `
+WITH matured AS (
+    UPDATE user_affiliate_ledger
+    SET frozen_until = NULL, updated_at = NOW()
+    WHERE user_id = $1
+      AND frozen_until IS NOT NULL
+      AND frozen_until <= NOW()
+    RETURNING amount
+)
+SELECT COALESCE(SUM(amount), 0)::double precision FROM matured`, userID)
+	if err != nil {
+		return 0, fmt.Errorf("thaw frozen quota: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	var thawed float64
+	if !rows.Next() {
+		return 0, rows.Err() // the aggregate yields one row, so no row means iteration failed
+	}
+	if err := rows.Scan(&thawed); err != nil {
+		return 0, err
+	}
+	if err := rows.Close(); err != nil {
+		return 0, err
+	}
+	if thawed <= 0 {
+		return 0, nil
+	}
+	if _, err = txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_quota = aff_quota + $1,
+    aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
+    updated_at = NOW()
+WHERE user_id = $2`, thawed, userID); err != nil {
+		return 0, fmt.Errorf("move thawed quota: %w", err)
+	}
+	return thawed, nil
+}
+
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
var transferred float64
var newBalance float64
@@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID
return err
}
+ // Thaw any matured frozen quota before transfer.
+ if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
+ return fmt.Errorf("thaw before transfer: %w", err)
+ }
+
rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS (
SELECT aff_quota::double precision AS amount
@@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64,
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
- ua.created_at
+ ua.created_at,
+ COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
+LEFT JOIN user_affiliate_ledger ual
+ ON ual.user_id = $1
+ AND ual.source_user_id = ua.user_id
+ AND ual.action = 'accrue'
WHERE ua.inviter_id = $1
+GROUP BY ua.user_id, u.email, u.username, ua.created_at
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
if err != nil {
@@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit)
for rows.Next() {
var item service.AffiliateInvitee
var createdAt time.Time
- if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
+ if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
return nil, err
}
item.CreatedAt = &createdAt
@@ -299,6 +393,7 @@ SELECT user_id,
inviter_id,
aff_count,
aff_quota::double precision,
+ aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
@@ -326,6 +421,7 @@ WHERE user_id = $1`, userID)
&inviterID,
&out.AffCount,
&out.AffQuota,
+ &out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
@@ -351,6 +447,7 @@ SELECT user_id,
inviter_id,
aff_count,
aff_quota::double precision,
+ aff_frozen_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
@@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
&inviterID,
&out.AffCount,
&out.AffQuota,
+ &out.AffFrozenQuota,
&out.AffHistoryQuota,
&out.CreatedAt,
&out.UpdatedAt,
diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go
index 369f57cf..697a193b 100644
--- a/backend/internal/repository/affiliate_repo_integration_test.go
+++ b/backend/internal/repository/affiliate_repo_integration_test.go
@@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
require.NoError(t, err)
require.True(t, bound, "invitee must bind to inviter")
- applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
+ applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
require.NoError(t, err)
require.True(t, applied, "AccrueQuota must report applied=true")
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index b1070aba..f1a42ef7 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -24,6 +24,49 @@ const (
defaultSchedulerSnapshotMGetChunkSize = 128
defaultSchedulerSnapshotWriteChunkSize = 256
+
+ // snapshotGraceTTLSeconds is the grace period, in seconds, before an old snapshot expires.
+ // It replaces an immediate DEL so that readers still on the old version have enough time to finish their ZRANGE.
+ snapshotGraceTTLSeconds = 60
+)
+
+var (
+ // activateSnapshotScript atomically switches the active snapshot version (a compare-and-set).
+ // The switch is refused only when the currently active value parses as a number greater
+ // than the new version, which prevents concurrent writers from rolling the version back.
+ // The previous snapshot is given a grace-period EXPIRE instead of an immediate DEL, to
+ // avoid racing with readers.
+ //
+ // KEYS[1] = activeKey (sched:active:{bucket})
+ // KEYS[2] = readyKey (sched:ready:{bucket})
+ // KEYS[3] = bucketSetKey (sched:buckets)
+ // KEYS[4] = snapshotKey (the newly written snapshot key)
+ // ARGV[1] = new version number, as a string
+ // ARGV[2] = bucket string (used for SADD)
+ // ARGV[3] = snapshot key prefix (used to build the old snapshot key)
+ // ARGV[4] = grace-period TTL in seconds
+ //
+ // Returns 1 when activated, 0 when the version was stale and not activated
+ // (in the stale case the script also deletes KEYS[4]).
+ //
+ // NOTE(review): the EXPIRE target is built inside the script as ARGV[3] .. currentActive
+ // and is not listed in KEYS — confirm this is acceptable for the Redis deployment in use.
+ activateSnapshotScript = redis.NewScript(`
+local currentActive = redis.call('GET', KEYS[1])
+local newVersion = tonumber(ARGV[1])
+
+if currentActive ~= false then
+ local curVersion = tonumber(currentActive)
+ if curVersion and newVersion < curVersion then
+ redis.call('DEL', KEYS[4])
+ return 0
+ end
+end
+
+redis.call('SET', KEYS[1], ARGV[1])
+redis.call('SET', KEYS[2], '1')
+redis.call('SADD', KEYS[3], ARGV[2])
+
+if currentActive ~= false and currentActive ~= ARGV[1] then
+ redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
+end
+
+return 1
+`)
)
type schedulerCache struct {
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
}
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
- activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
- oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
-
+ // Phase 1: 分配新版本号并写入快照数据。
+ // INCR 保证每个调用方获得唯一递增版本号。
+ // 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
version, err := c.rdb.Incr(ctx, versionKey).Result()
if err != nil {
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
return err
}
- pipe := c.rdb.Pipeline()
if len(accounts) > 0 {
// 使用序号作为 score,保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10),
})
}
+ pipe := c.rdb.Pipeline()
for start := 0; start < len(members); start += c.writeChunkSize {
end := start + c.writeChunkSize
if end > len(members) {
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
}
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
}
- } else {
- pipe.Del(ctx, snapshotKey)
- }
- pipe.Set(ctx, activeKey, versionStr, 0)
- pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
- pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
- if _, err := pipe.Exec(ctx); err != nil {
- return err
+ if _, err := pipe.Exec(ctx); err != nil {
+ return err
+ }
}
- if oldActive != "" && oldActive != versionStr {
- _ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
+ // Phase 2: 原子 CAS 激活版本。
+ // Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针,
+ // 防止并发写入导致版本回滚。
+ // 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。
+ activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+ readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
+ snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
+
+ keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
+ args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
+
+ _, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
+ if err != nil {
+ return err
}
return nil
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}
+// UnlockBucket deletes the lock key of the given bucket and returns the error, if any,
+// reported by the DEL command.
+//
+// NOTE(review): the key is removed unconditionally. TryLockBucket stores a timestamp rather
+// than an owner token, so this presumably can release a lock that another holder acquired
+// after the original TTL expired — confirm that callers tolerate this.
+func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+ key := schedulerBucketKey(schedulerLockPrefix, bucket)
+ return c.rdb.Del(ctx, key).Err()
+}
+
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
if err != nil {
@@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
SessionWindowStart: account.SessionWindowStart,
SessionWindowEnd: account.SessionWindowEnd,
SessionWindowStatus: account.SessionWindowStatus,
+ AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
+ GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
Credentials: filterSchedulerCredentials(account.Credentials),
Extra: filterSchedulerExtra(account.Extra),
}
}
+// filterSchedulerAccountGroups returns a slimmed copy of accountGroups for scheduler metadata.
+// Entries with a non-positive GroupID are dropped, and only AccountID, GroupID, Priority and
+// CreatedAt are copied; every other field of an entry is left at its zero value.
+// It returns nil when the input is empty or when no entry survives the filter.
+func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
+ if len(accountGroups) == 0 {
+ return nil
+ }
+
+ filtered := make([]service.AccountGroup, 0, len(accountGroups))
+ for _, ag := range accountGroups {
+ // Skip memberships that do not reference a valid (positive) group ID.
+ if ag.GroupID <= 0 {
+ continue
+ }
+ // Build a fresh value so that fields not listed here are not carried over.
+ filtered = append(filtered, service.AccountGroup{
+ AccountID: ag.AccountID,
+ GroupID: ag.GroupID,
+ Priority: ag.Priority,
+ CreatedAt: ag.CreatedAt,
+ })
+ }
+ // Normalize "nothing kept" to nil rather than an empty, non-nil slice.
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
+
+// filterSchedulerGroupIDs merges group IDs from groupIDs and from the GroupID of each entry
+// in accountGroups into a single de-duplicated list. Non-positive IDs are skipped. Order is
+// first-seen: IDs from groupIDs come first, followed by any new IDs found in accountGroups.
+// It returns nil when both inputs are empty or when no positive ID is found.
+func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
+ if len(groupIDs) == 0 && len(accountGroups) == 0 {
+ return nil
+ }
+
+ // seen tracks IDs already emitted so each ID appears at most once in the result.
+ seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups))
+ filtered := make([]int64, 0, len(groupIDs)+len(accountGroups))
+ for _, id := range groupIDs {
+ if id <= 0 {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ filtered = append(filtered, id)
+ }
+ // Add group IDs that only appear in the membership records.
+ for _, ag := range accountGroups {
+ if ag.GroupID <= 0 {
+ continue
+ }
+ if _, ok := seen[ag.GroupID]; ok {
+ continue
+ }
+ seen[ag.GroupID] = struct{}{}
+ filtered = append(filtered, ag.GroupID)
+ }
+ // Normalize "nothing kept" to nil rather than an empty, non-nil slice.
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
+
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
if len(credentials) == 0 {
return nil
diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go
index 134a6a07..948c2c73 100644
--- a/backend/internal/repository/scheduler_cache_integration_test.go
+++ b/backend/internal/repository/scheduler_cache_integration_test.go
@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
SessionWindowStart: &now,
SessionWindowEnd: &windowEnd,
SessionWindowStatus: "active",
+ GroupIDs: []int64{bucket.GroupID},
+ AccountGroups: []service.AccountGroup{
+ {
+ AccountID: 101,
+ GroupID: bucket.GroupID,
+ Priority: 5,
+ Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
+ },
+ },
}
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
require.Equal(t, 4, got.GetMaxSessions())
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
require.Nil(t, got.Extra["unused_large_field"])
+ require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
+ require.Len(t, got.AccountGroups, 1)
+ require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
+ require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
+ require.Nil(t, got.AccountGroups[0].Group)
full, err := cache.GetAccount(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, full)
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
+ require.Len(t, full.AccountGroups, 1)
+ require.NotNil(t, full.AccountGroups[0].Group)
}
diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go
index d302c5ea..32dda0a8 100644
--- a/backend/internal/repository/scheduler_cache_unit_test.go
+++ b/backend/internal/repository/scheduler_cache_unit_test.go
@@ -56,3 +56,43 @@ func TestBuildSchedulerMetadataAccount_KeepsModelRateLimits(t *testing.T) {
require.Equal(t, modelLimits, got.Extra["model_rate_limits"], "model_rate_limits must be carried into scheduler snapshot for rate-limit-aware selection")
require.Nil(t, got.Extra["unused_large_field"])
}
+
+// TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership checks that
+// buildSchedulerMetadataAccount keeps group membership in a reduced form: GroupIDs are
+// de-duplicated, stripped of non-positive values and extended with IDs found only in
+// AccountGroups, while each kept AccountGroup has its nested Account and Group pointers
+// cleared and the account's Groups field stays nil.
+func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
+ account := service.Account{
+ ID: 42,
+ Platform: service.PlatformAnthropic,
+ // Contains a duplicate (7) and an invalid ID (0) that must both be filtered out.
+ GroupIDs: []int64{7, 9, 7, 0},
+ AccountGroups: []service.AccountGroup{
+ {
+ AccountID: 42,
+ GroupID: 7,
+ Priority: 2,
+ Account: &service.Account{ID: 42, Name: "drop-from-metadata"},
+ Group: &service.Group{ID: 7, Name: "drop-from-metadata"},
+ },
+ {
+ // Group 11 is absent from GroupIDs and must be appended to the result.
+ AccountID: 42,
+ GroupID: 11,
+ Priority: 3,
+ Group: &service.Group{ID: 11, Name: "drop-from-metadata"},
+ },
+ {
+ // GroupID 0 is invalid and must be dropped from AccountGroups.
+ AccountID: 42,
+ GroupID: 0,
+ Priority: 4,
+ },
+ },
+ }
+
+ got := buildSchedulerMetadataAccount(account)
+
+ require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
+ require.Len(t, got.AccountGroups, 2)
+ require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
+ require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
+ require.Equal(t, 2, got.AccountGroups[0].Priority)
+ require.Nil(t, got.AccountGroups[0].Account)
+ require.Nil(t, got.AccountGroups[0].Group)
+ require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
+ require.Nil(t, got.Groups)
+}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index d605c52b..fabf3b5d 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 5,
"default_balance": 1.25,
"affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -737,6 +740,7 @@ func TestAPIContracts(t *testing.T) {
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_cch_signing": false,
+ "enable_anthropic_cache_ttl_1h_injection": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
@@ -745,6 +749,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
+ "openai_fast_policy_settings": {
+ "rules": [
+ {
+ "service_tier": "priority",
+ "action": "filter",
+ "scope": "all",
+ "fallback_action": "pass"
+ }
+ ]
+ },
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -898,6 +912,9 @@ func TestAPIContracts(t *testing.T) {
"default_concurrency": 0,
"default_balance": 0,
"affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
"default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
@@ -918,12 +935,23 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
+ "enable_anthropic_cache_ttl_1h_injection": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
+ "openai_fast_policy_settings": {
+ "rules": [
+ {
+ "service_tier": "priority",
+ "action": "filter",
+ "scope": "all",
+ "fallback_action": "pass"
+ }
+ ]
+ },
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index ce2a5dbe..cab0215b 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -66,6 +66,7 @@ func isOpenAIImageModel(model string) bool {
type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
+ claudeTokenProvider *ClaudeTokenProvider
antigravityGatewayService *AntigravityGatewayService
windsurfChatService *WindsurfChatService
httpUpstream HTTPUpstream
@@ -77,6 +78,7 @@ type AccountTestService struct {
func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
+ claudeTokenProvider *ClaudeTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
windsurfChatService *WindsurfChatService,
httpUpstream HTTPUpstream,
@@ -86,6 +88,7 @@ func NewAccountTestService(
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
+ claudeTokenProvider: claudeTokenProvider,
antigravityGatewayService: antigravityGatewayService,
windsurfChatService: windsurfChatService,
httpUpstream: httpUpstream,
@@ -223,6 +226,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID, prompt)
}
+ if account.Type == AccountTypeServiceAccount {
+ return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
+ }
// Determine authentication method and API URL
var authToken string
@@ -326,6 +332,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
+// testClaudeVertexServiceAccountConnection runs a connection test for a Claude account of
+// type service account by sending a streaming request to a Vertex URL.
+//
+// It resolves the model (account mapping first, otherwise the normalized Vertex Anthropic
+// model ID), writes SSE response headers, builds the Vertex request body, obtains an access
+// token from claudeTokenProvider, sends the request through httpUpstream and streams the
+// result with processClaudeStream. Every failure is reported to the client through
+// sendErrorAndEnd, whose return value is returned. A 403 response additionally records the
+// error message on the account via accountRepo.SetError.
+func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
+ // Prefer an account-level model mapping; otherwise normalize the ID for Vertex.
+ if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
+ testModelID = mappedModel
+ } else {
+ testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
+ }
+
+ // SSE headers are sent before any validation, so later errors are delivered as events.
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ payload, err := createTestPayload(testModelID, "")
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create test payload")
+ }
+ // NOTE(review): the json.Marshal error is discarded here — confirm the payload can never fail to encode.
+ payloadBytes, _ := json.Marshal(payload)
+ vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
+ }
+
+ if s.claudeTokenProvider == nil {
+ return s.sendErrorAndEnd(c, "Claude token provider not configured")
+ }
+ accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
+ }
+
+ fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
+ }
+
+ // Tell the client which model is actually being tested (after mapping/normalization).
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ // Use the account's proxy only when both the proxy ID and the loaded proxy are present.
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ // NOTE(review): s.tlsFPProfileService is dereferenced without a nil check, unlike claudeTokenProvider above — confirm it is always set.
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ // NOTE(review): the error body is read without a size limit and its read error is ignored — confirm this is acceptable.
+ body, _ := io.ReadAll(resp.Body)
+ errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
+ // A 403 is persisted on the account; the SetError result itself is ignored.
+ if resp.StatusCode == http.StatusForbidden {
+ _ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+ }
+ return s.sendErrorAndEnd(c, errMsg)
+ }
+
+ return s.processClaudeStream(c, resp.Body)
+}
+
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string, prompt string) error {
region := bedrockRuntimeRegion(account)
@@ -728,8 +802,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
testModelID = geminicli.DefaultTestModel
}
- // For API Key accounts with model mapping, map the model
- if account.Type == AccountTypeAPIKey {
+ // For static upstream credentials with model mapping, map the model
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -757,6 +831,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
+ case AccountTypeServiceAccount:
+ req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -948,6 +1024,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
}
+// buildGeminiServiceAccountRequest builds the POST request used to test a Gemini account of
+// type service account. It obtains an access token from geminiTokenProvider, builds the
+// Vertex URL for the "streamGenerateContent" action from the account's project ID and
+// location, and sends payload as a JSON body with a Bearer Authorization header.
+// It returns an error when the token provider is not configured, the token cannot be
+// obtained, the URL cannot be built, or the request cannot be created.
+func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
+ if s.geminiTokenProvider == nil {
+ return nil, fmt.Errorf("gemini token provider not configured")
+ }
+ accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get service account access token: %w", err)
+ }
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ return req, nil
+}
+
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
var inner map[string]any
@@ -1286,7 +1383,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
- apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+ apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go
index 80a2fc31..257159c4 100644
--- a/backend/internal/service/account_test_service_openai_image_test.go
+++ b/backend/internal/service/account_test_service_openai_image_test.go
@@ -8,6 +8,7 @@ import (
"strings"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes
require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
require.Contains(t, rec.Body.String(), "\"success\":true")
}
+
+// TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL checks that an OpenAI
+// API-key account whose base_url already ends in "/v1" is tested against
+// "<base_url>/images/generations" (the "/v1" segment is not duplicated), that the API key is
+// sent as a Bearer token, and that the image result is reported as a success event.
+func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+
+ // The recorder stub returns a canned image response and captures the outgoing request.
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ },
+ Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: &config.Config{},
+ }
+ account := &Account{
+ ID: 54,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ "base_url": "https://image-upstream.example/v1",
+ },
+ }
+
+ err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
+ require.NoError(t, err)
+ require.NotNil(t, upstream.lastReq)
+ require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
+ require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+ require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+ require.Contains(t, rec.Body.String(), "\"success\":true")
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 147ee4e1..b854c16e 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"sort"
+ "strconv"
"strings"
"time"
@@ -58,6 +59,7 @@ type AdminService interface {
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
+ AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
@@ -291,6 +293,7 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
+ Filters *BulkUpdateAccountFilters
Name string
ProxyID *int64
Concurrency *int
@@ -307,6 +310,15 @@ type BulkUpdateAccountsInput struct {
SkipMixedChannelCheck bool
}
+// BulkUpdateAccountFilters selects the target accounts of a bulk update by list filters
+// instead of explicit IDs. resolveBulkUpdateTargetIDs passes Platform, Type, Status, Search
+// and PrivacyMode unchanged to ListAccounts.
+type BulkUpdateAccountFilters struct {
+ Platform string
+ Type string
+ Status string
+ // Group is "" (no group filter), "ungrouped", or a base-10 group ID.
+ Group string
+ Search string
+ PrivacyMode string
+}
+
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
@@ -1961,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil
}
+// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
+// It loads the key by ID, zeroes the 5h/1d/7d usage counters, clears the three window start
+// times, persists the key, and then invalidates the auth cache entry and the billing
+// rate-limit cache entry when those services are configured. It returns the updated key, or
+// an error when the key cannot be loaded or saved.
+//
+// NOTE(review): the key is loaded and written back through Update; whether a concurrent
+// usage increment can be lost depends on how Update persists these fields — confirm.
+func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
+ apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
+ if err != nil {
+ return nil, err
+ }
+ apiKey.Usage5h = 0
+ apiKey.Usage1d = 0
+ apiKey.Usage7d = 0
+ apiKey.Window5hStart = nil
+ apiKey.Window1dStart = nil
+ apiKey.Window7dStart = nil
+ if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
+ return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
+ }
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+ if s.billingCacheService != nil {
+ // Best effort: the reset is already persisted, so a cache invalidation error is ignored.
+ _ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
+ }
+ return apiKey, nil
+}
+
// ReplaceUserGroup 替换用户的专属分组
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
if oldGroupID == newGroupID {
@@ -2286,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
+ if len(input.AccountIDs) == 0 && input.Filters != nil {
+ accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
+ if err != nil {
+ return nil, err
+ }
+ input.AccountIDs = accountIDs
+ }
+
result := &BulkUpdateAccountsResult{
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
@@ -2401,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
+// resolveBulkUpdateTargetIDs expands filters into the IDs of all matching accounts by paging
+// through ListAccounts, 500 accounts per page. A nil filters value yields (nil, nil).
+// The Group filter is interpreted as: "" -> group ID 0, "ungrouped" ->
+// AccountListGroupUngrouped, anything else -> parsed as a base-10 int64 (a parse failure is
+// returned as an "invalid group filter" error). Errors from ListAccounts are returned as is.
+//
+// NOTE(review): pages are fetched one after another; if matching accounts are added or
+// removed between calls, IDs could presumably be skipped or repeated — confirm the ordering
+// guarantees of ListAccounts.
+func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
+ if filters == nil {
+ return nil, nil
+ }
+
+ groupID := int64(0)
+ switch strings.TrimSpace(filters.Group) {
+ case "":
+ // No group filter: keep groupID at 0.
+ case "ungrouped":
+ groupID = AccountListGroupUngrouped
+ default:
+ parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid group filter: %w", err)
+ }
+ groupID = parsedGroupID
+ }
+
+ const pageSize = 500
+ page := 1
+ accountIDs := make([]int64, 0, pageSize)
+
+ for {
+ // The two trailing arguments are passed as empty strings; their meaning is defined by ListAccounts.
+ accounts, total, err := s.ListAccounts(
+ ctx,
+ page,
+ pageSize,
+ filters.Platform,
+ filters.Type,
+ filters.Status,
+ filters.Search,
+ groupID,
+ filters.PrivacyMode,
+ "",
+ "",
+ )
+ if err != nil {
+ return nil, err
+ }
+ for _, account := range accounts {
+ accountIDs = append(accountIDs, account.ID)
+ }
+ // Stop once everything reported by total is collected, or when a page comes back
+ // empty (which also guards against looping forever if total is never reached).
+ if int64(len(accountIDs)) >= total || len(accounts) == 0 {
+ return accountIDs, nil
+ }
+ page++
+ }
+}
+
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index 4845d87c..df415295 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
+ "reflect"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
+ listData []Account
+ listResult *pagination.PaginationResult
+ listErr error
+ listCalled bool
+ lastListParams pagination.PaginationParams
+ lastListFilters struct {
+ platform string
+ accountType string
+ status string
+ search string
+ groupID int64
+ privacyMode string
+ }
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
return nil, nil
}
+// ListWithFilters is the test stub for the filtered account listing. It records that it was
+// called together with the pagination params and every filter argument, then returns, in
+// order of precedence: listErr if set; listData with listResult if listResult is set;
+// otherwise listData with a pagination result whose Total equals len(listData).
+func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
+ s.listCalled = true
+ s.lastListParams = params
+ s.lastListFilters.platform = platform
+ s.lastListFilters.accountType = accountType
+ s.lastListFilters.status = status
+ s.lastListFilters.search = search
+ s.lastListFilters.groupID = groupID
+ s.lastListFilters.privacyMode = privacyMode
+ if s.listErr != nil {
+ return nil, nil, s.listErr
+ }
+ if s.listResult != nil {
+ return s.listData, s.listResult, nil
+ }
+ return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
+}
+
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
// No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls)
}
+
+// TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters checks that BulkUpdateAccounts,
+// when given Filters and no explicit AccountIDs, resolves the target IDs through the
+// filtered account listing, forwards every filter value (with Group "12" parsed to the
+// numeric group ID 12), and applies the bulk update to exactly the listed accounts.
+func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ listData: []Account{
+ {ID: 7},
+ {ID: 11},
+ },
+ listResult: &pagination.PaginationResult{Total: 2},
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ schedulable := true
+ input := &BulkUpdateAccountsInput{
+ Schedulable: &schedulable,
+ }
+
+ // The Filters field is located and populated through reflection, so the assertions below
+ // also verify that the field exists and is a pointer.
+ filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
+ require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
+ require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
+
+ filtersValue := reflect.New(filtersField.Type().Elem())
+ filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
+ filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
+ filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
+ filtersValue.Elem().FieldByName("Group").SetString("12")
+ filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
+ filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
+ filtersField.Set(filtersValue)
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
+ require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
+ require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
+ require.Equal(t, StatusActive, repo.lastListFilters.status)
+ require.Equal(t, "bulk-target", repo.lastListFilters.search)
+ require.Equal(t, int64(12), repo.lastListFilters.groupID)
+ require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
+ require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
+ require.Equal(t, 2, result.Success)
+ require.Equal(t, 0, result.Failed)
+ require.Equal(t, []int64{7, 11}, result.SuccessIDs)
+}
diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go
index aca32076..5a4e91e7 100644
--- a/backend/internal/service/affiliate_service.go
+++ b/backend/internal/service/affiliate_service.go
@@ -65,16 +65,18 @@ type AffiliateSummary struct {
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AffiliateInvitee struct {
- UserID int64 `json:"user_id"`
- Email string `json:"email"`
- Username string `json:"username"`
- CreatedAt *time.Time `json:"created_at,omitempty"`
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ CreatedAt *time.Time `json:"created_at,omitempty"`
+ TotalRebate float64 `json:"total_rebate"`
}
type AffiliateDetail struct {
@@ -83,6 +85,7 @@ type AffiliateDetail struct {
InviterID *int64 `json:"inviter_id,omitempty"`
AffCount int `json:"aff_count"`
AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
AffHistoryQuota float64 `json:"aff_history_quota"`
// EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
// 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
@@ -95,7 +98,9 @@ type AffiliateRepository interface {
EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
- AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error)
+ AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
+ GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
+ ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
@@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64
}
func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
+ // Lazy thaw: move any matured frozen quota to available before reading.
+ if s != nil && s.repo != nil {
+ // best-effort: thaw failure is non-fatal
+ _, _ = s.repo.ThawFrozenQuota(ctx, userID)
+ }
+
summary, err := s.EnsureUserAffiliate(ctx, userID)
if err != nil {
return nil, err
@@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64)
InviterID: summary.InviterID,
AffCount: summary.AffCount,
AffQuota: summary.AffQuota,
+ AffFrozenQuota: summary.AffFrozenQuota,
AffHistoryQuota: summary.AffHistoryQuota,
EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
Invitees: invitees,
@@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID
if err != nil {
return 0, err
}
+ // 有效期检查:超过返利有效期后不再产生返利
+ if s.settingService != nil {
+ if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
+ if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
+ return 0, nil
+ }
+ }
+ }
+
rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
if rebate <= 0 {
return 0, nil
}
- applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate)
+	// Per-invitee cap: clamp the rebate to what is left of the cap for this invitee.
+	if s.settingService != nil {
+		if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
+			existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
+			if err != nil {
+				return 0, err
+			}
+			// Cap already reached: nothing more to accrue.
+			if existing >= perInviteeCap {
+				return 0, nil
+			}
+			if remaining := perInviteeCap - existing; rebate > remaining {
+				rebate = roundTo(remaining, 8)
+			}
+			// A remainder below the rounding precision collapses to zero after
+			// the clamp; skip the accrual instead of recording an empty rebate.
+			if rebate <= 0 {
+				return 0, nil
+			}
+		}
+	}
+
+ var freezeHours int
+ if s.settingService != nil {
+ freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
+ }
+
+ applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
if err != nil {
return 0, err
}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index a18cf39c..9815f31b 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
user *User,
invitationCode string,
signupSource string,
+ affiliateCode string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
@@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 08b0f4b7..b1adf071 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
-func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
+// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
+func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
}
}
+// bindOAuthAffiliate makes sure an OAuth-registered user has an affiliate
+// profile and, when a non-blank affiliate code was supplied, binds the user to
+// the inviter behind that code. Errors are only logged so that registration is
+// never interrupted by affiliate bookkeeping.
+func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
+	if s.affiliateService == nil || userID <= 0 {
+		return
+	}
+
+	_, ensureErr := s.affiliateService.EnsureUserAffiliate(ctx, userID)
+	if ensureErr != nil {
+		logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, ensureErr)
+	}
+
+	// Binding is attempted even if profile initialization failed above.
+	trimmed := strings.TrimSpace(affiliateCode)
+	if trimmed == "" {
+		return
+	}
+	bindErr := s.affiliateService.BindInviterByCode(ctx, userID, trimmed)
+	if bindErr != nil {
+		logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, bindErr)
+	}
+}
+
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index c1ad6240..acc44a38 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
- tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "")
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
- tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "")
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index 4e695eb9..050db55b 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil
}
+// InvalidateAPIKeyRateLimit drops the cached rate-limit usage for the given API key.
+// It does nothing when no cache is configured; a cache failure is logged and
+// handed back to the caller.
+func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
+	if s.cache == nil {
+		return nil
+	}
+	err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID)
+	if err == nil {
+		return nil
+	}
+	logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
+	return err
+}
+
// ============================================
// API Key 限速缓存方法
// ============================================
diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go
index 82fa31c4..d70379c1 100644
--- a/backend/internal/service/claude_token_provider.go
+++ b/backend/internal/service/claude_token_provider.go
@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
-// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
+// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
- if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
- return "", errors.New("not an anthropic oauth account")
+ if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+ return "", errors.New("not an anthropic oauth or service account")
+ }
+ if account.Type == AccountTypeServiceAccount {
+ return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
+
+// getServiceAccountAccessToken returns the access token for a service-account
+// account by delegating to getVertexServiceAccountAccessToken with this
+// provider's token cache.
+func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+ return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
+}
diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go
index 3e21f6f4..d4a4a14a 100644
--- a/backend/internal/service/claude_token_provider_test.go
+++ b/backend/internal/service/claude_token_provider_test.go
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
- return "", errors.New("not an anthropic oauth account")
+ return "", errors.New("not an anthropic oauth or service account")
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 1c8e7cc9..9793b6b2 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -20,10 +20,15 @@ const (
// Affiliate rebate settings
const (
- AffiliateRebateRateDefault = 20.0
- AffiliateRebateRateMin = 0.0
- AffiliateRebateRateMax = 100.0
- AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
+ AffiliateRebateRateDefault = 20.0
+ AffiliateRebateRateMin = 0.0
+ AffiliateRebateRateMax = 100.0
+ AffiliateEnabledDefault = false // invite-rebate master switch is off by default
+ AffiliateRebateFreezeHoursDefault = 0 // 0 = no freeze (backward compatible)
+ AffiliateRebateFreezeHoursMax = 720 // at most 30 days
+ AffiliateRebateDurationDaysDefault = 0 // 0 = never expires
+ AffiliateRebateDurationDaysMax = 3650 // ~10 years
+ AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = no cap
)
// Platform constants
@@ -37,11 +42,12 @@ const (
// Account type constants
const (
- AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
- AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
- AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
- AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
- AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+ AccountTypeOAuth = domain.AccountTypeOAuth // OAuth account (full scope: profile + inference)
+ AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
+ AccountTypeAPIKey = domain.AccountTypeAPIKey // API key account
+ AccountTypeUpstream = domain.AccountTypeUpstream // upstream passthrough account (connects to the upstream via Base URL + API key)
+ AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock account (connects via SigV4 signing or API key, selected by credentials.auth_mode)
+ AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google service account (used for Vertex AI)
)
// Redeem type constants
@@ -98,6 +104,9 @@ const (
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
+ SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
+ SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
+ SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -299,6 +308,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
+ // SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
+ // service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
+ // targets OpenAI's body-level service_tier field instead of Claude's
+ // anthropic-beta header.
+ SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
+
// =========================
// Claude Code Version Check
// =========================
@@ -322,6 +337,8 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
+ // SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
+ SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
diff --git a/backend/internal/service/gateway_anthropic_vertex_service_account_test.go b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go
new file mode 100644
index 00000000..aa779805
--- /dev/null
+++ b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go
@@ -0,0 +1,68 @@
+package service
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// TestGatewayService_BuildAnthropicVertexServiceAccountRequest checks the
+// upstream request that buildUpstreamRequest produces for an Anthropic
+// service-account account: the rawPredict URL, the bearer token, which inbound
+// headers survive, and how the JSON body is rewritten.
+func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // Inbound request carrying client credentials and Anthropic headers.
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+ c.Request.Header.Set("Authorization", "Bearer inbound-token")
+ c.Request.Header.Set("X-Api-Key", "inbound-api-key")
+ c.Request.Header.Set("Anthropic-Version", "2023-06-01")
+ c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
+
+ // project_id and location are expected to show up in the upstream URL below.
+ account := &Account{
+ ID: 301,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeServiceAccount,
+ Credentials: map[string]any{
+ "project_id": "vertex-proj",
+ "location": "us-east5",
+ },
+ }
+ body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
+
+ svc := &GatewayService{}
+ req, err := svc.buildUpstreamRequest(
+ context.Background(),
+ c,
+ account,
+ body,
+ "vertex-token",
+ "service_account",
+ "claude-sonnet-4-5@20250929",
+ false,
+ false,
+ )
+ require.NoError(t, err)
+ // URL embeds location, project and the versioned model ID, ending in :rawPredict.
+ require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
+ // The supplied token replaces the inbound Authorization; x-api-key and
+ // anthropic-version headers are absent, while anthropic-beta is kept.
+ require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
+ require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
+ require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
+ require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
+
+ // Body: no "model" value, anthropic_version set to vertexAnthropicVersion, messages untouched.
+ got := readRequestBodyForTest(t, req)
+ require.Equal(t, "", gjson.GetBytes(got, "model").String())
+ require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
+ require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
+}
+
+// readRequestBodyForTest drains req.Body and returns its bytes, failing the
+// test when the body is nil or cannot be read.
+func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
+	t.Helper()
+	require.NotNil(t, req.Body)
+
+	payload, readErr := io.ReadAll(req.Body)
+	require.NoError(t, readErr)
+	return payload
+}
diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go
index e6c9de7d..e0c3cafd 100644
--- a/backend/internal/service/gateway_body_order_test.go
+++ b/backend/internal/service/gateway_body_order_test.go
@@ -1,13 +1,91 @@
package service
import (
+ "context"
+ "errors"
"strings"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
+// gatewayTTLSettingRepo is an in-memory, map-backed settings store used by the
+// tests in this file; it is handed to NewSettingService as the repository.
+type gatewayTTLSettingRepo struct {
+ data map[string]string
+}
+
+// Get always reports ErrSettingNotFound; this stub does not keep Setting records.
+func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
+ return nil, ErrSettingNotFound
+}
+
+// GetValue looks key up in the in-memory map. A nil receiver or an absent key
+// yields ErrSettingNotFound.
+func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
+	if r != nil {
+		if stored, found := r.data[key]; found {
+			return stored, nil
+		}
+	}
+	return "", ErrSettingNotFound
+}
+
+// Set stores value under key, allocating the backing map on first use. A nil
+// receiver is rejected with an error.
+func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
+	switch {
+	case r == nil:
+		return errors.New("setting repo is nil")
+	case r.data == nil:
+		r.data = map[string]string{key: value}
+	default:
+		r.data[key] = value
+	}
+	return nil
+}
+
+// GetMultiple returns the stored values for those of keys that exist; unknown
+// keys are simply omitted. A nil receiver yields an empty, non-nil map.
+func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+	found := make(map[string]string)
+	if r != nil {
+		for _, k := range keys {
+			value, present := r.data[k]
+			if !present {
+				continue
+			}
+			found[k] = value
+		}
+	}
+	return found, nil
+}
+
+// SetMultiple copies every entry of settings into the in-memory map,
+// allocating the map when needed. A nil receiver is rejected with an error.
+func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
+	if r == nil {
+		return errors.New("setting repo is nil")
+	}
+	if r.data == nil {
+		r.data = make(map[string]string)
+	}
+	for k := range settings {
+		r.data[k] = settings[k]
+	}
+	return nil
+}
+
+// GetAll returns a copy of every stored key/value pair. A nil receiver yields
+// an empty, non-nil map.
+func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
+	snapshot := make(map[string]string)
+	if r != nil {
+		for k, v := range r.data {
+			snapshot[k] = v
+		}
+	}
+	return snapshot, nil
+}
+
+// Delete removes key from the in-memory map. It never fails; a nil receiver is
+// a no-op.
+func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
+	if r == nil {
+		return nil
+	}
+	delete(r.data, key)
+	return nil
+}
+
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
@@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}
+
+// TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl
+// verifies that injectAnthropicCacheControlTTL1h sets ttl to "1h" on existing
+// ephemeral cache_control objects only: blocks without cache_control stay
+// without one, a non-ephemeral cache_control keeps its ttl, and top-level field
+// order is preserved.
+func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
+	input := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
+
+	output := injectAnthropicCacheControlTTL1h(input)
+
+	assertJSONTokenOrder(t, string(output), `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
+
+	checks := []struct {
+		path   string
+		absent bool
+		ttl    string
+	}{
+		{path: "cache_control.ttl", ttl: "1h"},
+		{path: "system.0.cache_control.ttl", ttl: "1h"},
+		{path: "system.1.cache_control", absent: true},
+		{path: "messages.0.content.0.cache_control.ttl", ttl: "1h"},
+		{path: "messages.0.content.1.cache_control.ttl", ttl: "5m"},
+		{path: "tools.0.cache_control.ttl", ttl: "1h"},
+	}
+	for _, check := range checks {
+		got := gjson.GetBytes(output, check.path)
+		if check.absent {
+			require.False(t, got.Exists())
+			continue
+		}
+		require.Equal(t, check.ttl, got.String())
+	}
+}
+
+// TestGatewayCacheTTLGlobalSetting_TargetResolution checks what
+// resolveCacheTTLUsageOverrideTarget returns for an Anthropic OAuth account
+// while the global 1h-injection setting is "true": cacheTTLTarget5m without an
+// account-level override, and cacheTTLTarget1h once the account's Extra enables
+// an override targeting "1h".
+func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
+ repo := &gatewayTTLSettingRepo{data: map[string]string{
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
+ }}
+ // NOTE(review): presumably resets the package-level forwarding-settings cache so the repo value is re-read — confirm.
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ svc := &GatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+ account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
+
+ // No account-level override: the global setting alone resolves to the 5m target.
+ target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
+ require.True(t, ok)
+ require.Equal(t, cacheTTLTarget5m, target)
+
+ // With an account-level override enabled, its target is the one returned.
+ account.Extra = map[string]any{
+ "cache_ttl_override_enabled": true,
+ "cache_ttl_override_target": "1h",
+ }
+ target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
+ require.True(t, ok)
+ require.Equal(t, cacheTTLTarget1h, target)
+}
+
+// TestGatewayCacheTTLGlobalSetting_RequestInjectionScope checks which accounts
+// shouldInjectAnthropicCacheTTL1h accepts: with the global setting "true" only
+// Anthropic OAuth and SetupToken accounts qualify; with the setting "false"
+// even an Anthropic OAuth account is rejected.
+func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
+ repo := &gatewayTTLSettingRepo{data: map[string]string{
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
+ }}
+ // NOTE(review): presumably resets the package-level forwarding-settings cache so the repo value is re-read — confirm.
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ svc := &GatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+
+ // Setting enabled: Anthropic OAuth and SetupToken pass; Anthropic API key and OpenAI OAuth do not.
+ require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
+ require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
+
+ // Flip the setting off and reset the cache again before re-checking.
+ repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
+}
diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go
index c531667e..7ac77f77 100644
--- a/backend/internal/service/gateway_forward_as_chat_completions.go
+++ b/backend/internal/service/gateway_forward_as_chat_completions.go
@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
- if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+ if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+ if normalized != originalModel {
+ mappedModel = normalized
+ }
+ } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go
index 647193d6..8f8a1e94 100644
--- a/backend/internal/service/gateway_forward_as_responses.go
+++ b/backend/internal/service/gateway_forward_as_responses.go
@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
- if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+ if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+ if normalized != originalModel {
+ mappedModel = normalized
+ }
+ } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 701e56b2..319d4d0c 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -11,6 +11,7 @@ import (
"io"
"log/slog"
mathrand "math/rand"
+ "net"
"net/http"
"net/url"
"os"
@@ -20,6 +21,7 @@ import (
"strconv"
"strings"
"sync/atomic"
+ "syscall"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -63,6 +65,11 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
+// Cache TTL target identifiers ("5m" and "1h").
+const (
+ cacheTTLTarget5m = "5m"
+ cacheTTLTarget1h = "1h"
+)
+
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -335,9 +342,8 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
- sseDataRe = regexp.MustCompile(`^data:\s*`)
- claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
- claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`)
+ sseDataRe = regexp.MustCompile(`^data:\s*`)
+ claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@@ -677,15 +683,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
- if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
+ uid := ParseMetadataUserID(parsed.MetadataUserID)
+ if uid != nil && uid.SessionID != "" {
+ // Diagnostic trace of which input produced the session hash; kept at
+ // debug level like the sticky.layer1_5_* traces so session and device
+ // identifiers are not written to info logs.
+ slog.Debug("sticky.hash_source",
+ "source", "metadata_user_id",
+ "session_id", uid.SessionID,
+ "device_id", uid.DeviceID,
+ "is_new_format", uid.IsNewFormat,
+ )
return uid.SessionID
}
+ // The raw metadata.user_id comes from the request; log it at debug level only.
+ slog.Debug("sticky.hash_metadata_parse_failed",
+ "metadata_user_id", parsed.MetadataUserID,
+ "parsed_nil", uid == nil,
+ )
}
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
- return s.hashContent(cacheableContent)
+ hash := s.hashContent(cacheableContent)
+ // Diagnostic trace only; debug level.
+ slog.Debug("sticky.hash_source",
+ "source", "cacheable_content",
+ "hash", hash,
+ )
+ return hash
}
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
@@ -725,7 +747,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
}
if combined.Len() > 0 {
- return s.hashContent(combined.String())
+ hash := s.hashContent(combined.String())
+ // Diagnostic trace of the fallback hash source; debug level, matching the
+ // sticky.layer1_5_* traces.
+ slog.Debug("sticky.hash_source",
+ "source", "message_content_fallback",
+ "hash", hash,
+ "content_len", combined.Len(),
+ )
+ return hash
}
return ""
@@ -1432,14 +1460,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
var stickyAccountID int64
+ var stickySource string
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
+ stickySource = "prefetch"
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
+ stickySource = "cache"
}
}
+ // [DEBUG-STICKY] Scheduler entry trace. Emitted at debug level, like the
+ // sticky.layer1_5_* traces, because it is a debugging aid.
+ slog.Debug("sticky.scheduler_entry",
+ "group_id", derefGroupID(groupID),
+ "session_hash", shortSessionHash(sessionHash),
+ "sticky_account_id", stickyAccountID,
+ "sticky_source", stickySource,
+ "model", requestedModel,
+ "load_batch", cfg.LoadBatchEnabled,
+ "has_concurrency_svc", s.concurrencyService != nil,
+ "excluded_count", len(excludedIDs),
+ )
+
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
@@ -1615,6 +1658,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(routingCandidates) > 0 {
// 1.5. 在路由账号范围内检查粘性会话
if sessionHash != "" && stickyAccountID > 0 {
+ slog.Debug("sticky.layer1_5_checking",
+ "sticky_account_id", stickyAccountID,
+ "in_routing_list", containsInt64(routingAccountIDs, stickyAccountID),
+ "is_excluded", isExcluded(stickyAccountID),
+ "in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(),
+ "session", shortSessionHash(sessionHash),
+ )
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
@@ -1638,6 +1688,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
+ slog.Debug("sticky.layer1_5_hit",
+ "account_id", stickyAccountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "slot_acquired",
+ )
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
@@ -1788,27 +1843,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 检查账户是否需要清理粘性会话绑定
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
+ slog.Debug("sticky.layer1_5_no_routing_clear",
+ "account_id", accountID,
+ "reason", "should_clear_sticky_session",
+ "session", shortSessionHash(sessionHash),
+ )
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) &&
- s.isAccountAllowedForPlatform(account, platform, useMixed) &&
- (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
- s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
- s.isAccountSchedulableForQuota(account) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) &&
- s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
+ // 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的
+ // accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存
+ // 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。
+ platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed)
+ modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)
+ modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel)
+ quotaOK := s.isAccountSchedulableForQuota(account)
+ windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
+ rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
+ schedulable := s.isAccountSchedulableForSelection(account)
+
+ slog.Debug("sticky.layer1_5_no_routing_checks",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "clear_sticky", clearSticky,
+ "schedulable", schedulable,
+ "platform_ok", platformOK,
+ "model_supported", modelSupported,
+ "model_schedulable", modelSchedulable,
+ "quota_ok", quotaOK,
+ "window_cost_ok", windowCostOK,
+ "rpm_ok", rpmOK,
+ )
+
+ if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "session_limit",
+ "session", shortSessionHash(sessionHash),
+ )
} else {
+ slog.Debug("sticky.layer1_5_no_routing_hit",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "slot_acquired",
+ )
if s.cache != nil {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}
+ } else {
+ slog.Debug("sticky.layer1_5_no_routing_slot_busy",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ )
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@@ -1817,6 +1910,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
+ slog.Debug("sticky.layer1_5_no_routing_hit",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "wait_plan",
+ )
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
@@ -1825,12 +1923,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
})
}
}
+ } else if !clearSticky {
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "gate_check_failed",
+ "session", shortSessionHash(sessionHash),
+ )
}
+ } else {
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "account_not_in_map",
+ "session", shortSessionHash(sessionHash),
+ )
}
}
+ } else if len(routingAccountIDs) == 0 && sessionHash != "" {
+ slog.Debug("sticky.layer1_5_no_routing_skip",
+ "sticky_account_id", stickyAccountID,
+ "is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(),
+ "session", shortSessionHash(sessionHash),
+ "reason", func() string {
+ if stickyAccountID == 0 {
+ return "no_sticky_binding"
+ }
+ return "sticky_account_excluded"
+ }(),
+ )
}
// ============ Layer 2: 负载感知选择 ============
+ slog.Debug("sticky.layer2_fallback",
+ "session", shortSessionHash(sessionHash),
+ "sticky_account_id", stickyAccountID,
+ "reason", "sticky_not_used_falling_back_to_load_balance",
+ "total_accounts", len(accounts),
+ )
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
@@ -3654,7 +3782,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
- requestedModel = claude.NormalizeModelID(requestedModel)
+ if account.Type == AccountTypeServiceAccount {
+ requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
+ } else {
+ requestedModel = claude.NormalizeModelID(requestedModel)
+ }
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
@@ -3674,6 +3806,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
return apiKey, "apikey", nil
case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
+ case AccountTypeServiceAccount:
+ if account.Platform != PlatformAnthropic {
+ return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
+ }
+ if s.claudeTokenProvider == nil {
+ return "", "", errors.New("claude token provider not configured")
+ }
+ accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return "", "", err
+ }
+ return accessToken, "service_account", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -3781,16 +3925,6 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return ParseMetadataUserID(metadataUserID) != nil
}
-func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
- if IsClaudeCodeClient(ctx) {
- return true
- }
- if parsed == nil || c == nil {
- return false
- }
- return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
-}
-
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。
// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。
@@ -4153,6 +4287,87 @@ func enforceCacheControlLimit(body []byte) []byte {
return body
}
+// injectAnthropicCacheControlTTL1h forces the ttl of every existing ephemeral
+// cache_control block in body to the 1h target (cacheTTLTarget1h).
+// It only rewrites cache_control objects that are already present; it never
+// adds new cache breakpoints. The work is delegated to
+// forceEphemeralCacheControlTTL.
+func injectAnthropicCacheControlTTL1h(body []byte) []byte {
+ return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
+}
+
+// forceEphemeralCacheControlTTL rewrites the ttl of existing ephemeral
+// cache_control objects in a JSON request body to ttl and returns the result.
+//
+// Locations inspected, in order: the top-level cache_control, each element of
+// system[], each element of messages[].content[], and each element of tools[].
+// A cache_control is touched only when it exists, its type is "ephemeral", and
+// its current ttl differs from the target (a missing ttl counts as different).
+// No cache_control object is ever created.
+//
+// body is returned unchanged when it is empty or ttl is "". A path whose write
+// fails is skipped, leaving the other rewrites in place.
+func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
+	if len(body) == 0 || ttl == "" {
+		return body
+	}
+
+	// needsRewrite reports whether cc is an ephemeral cache_control whose ttl is not yet the target.
+	needsRewrite := func(cc gjson.Result) bool {
+		return cc.Exists() && cc.Get("type").String() == "ephemeral" && cc.Get("ttl").String() != ttl
+	}
+
+	var targets []string
+
+	// collect records the ttl path of every element of blocks (addressed as
+	// "<prefix>.<index>") that needs a rewrite. Non-array values are ignored.
+	collect := func(prefix string, blocks gjson.Result) {
+		if !blocks.IsArray() {
+			return
+		}
+		pos := 0
+		blocks.ForEach(func(_, item gjson.Result) bool {
+			if needsRewrite(item.Get("cache_control")) {
+				targets = append(targets, fmt.Sprintf("%s.%d.cache_control.ttl", prefix, pos))
+			}
+			pos++
+			return true
+		})
+	}
+
+	if needsRewrite(gjson.GetBytes(body, "cache_control")) {
+		targets = append(targets, "cache_control.ttl")
+	}
+
+	collect("system", gjson.GetBytes(body, "system"))
+
+	if msgs := gjson.GetBytes(body, "messages"); msgs.IsArray() {
+		msgPos := 0
+		msgs.ForEach(func(_, msg gjson.Result) bool {
+			// String-valued content has no blocks; collect skips non-arrays.
+			collect(fmt.Sprintf("messages.%d.content", msgPos), msg.Get("content"))
+			msgPos++
+			return true
+		})
+	}
+
+	collect("tools", gjson.GetBytes(body, "tools"))
+
+	// Paths were computed against the original body; setting a scalar ttl does
+	// not shift array indices, so they remain valid as writes accumulate.
+	rewritten := body
+	for _, target := range targets {
+		updated, err := sjson.SetBytes(rewritten, target, ttl)
+		if err != nil {
+			continue
+		}
+		rewritten = updated
+	}
+	return rewritten
+}
+
+// shouldInjectAnthropicCacheTTL1h reports whether the 1h cache-TTL rewrite
+// should be applied to requests sent through account. It is true only for a
+// non-nil Anthropic OAuth/SetupToken account when a setting service is
+// available and reports the injection as enabled.
+func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
+	// Only Anthropic OAuth / SetupToken accounts are eligible.
+	if account == nil || !account.IsAnthropicOAuthOrSetupToken() {
+		return false
+	}
+	// Without a setting service the flag cannot be read; treat it as disabled.
+	if s == nil || s.settingService == nil {
+		return false
+	}
+	return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
+}
+
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -4286,6 +4501,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
mappingSource = "account"
}
}
+ if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ if candidate, matched := account.ResolveMappedModel(reqModel); matched {
+ mappedModel = candidate
+ mappingSource = "account"
+ } else {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
+ if normalized != reqModel {
+ mappedModel = normalized
+ mappingSource = "vertex"
+ }
+ }
+ }
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
@@ -4300,6 +4527,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
+ if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
+ body = injectAnthropicCacheControlTTL1h(body)
+ }
+
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -5799,6 +6030,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
+ if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
+ }
+
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -6010,6 +6245,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
+// buildUpstreamRequestAnthropicVertex builds the upstream POST request for an
+// Anthropic-platform service-account account (the Vertex path).
+//
+// The client body is converted with buildVertexAnthropicRequestBody, the URL is
+// derived from the account's Vertex project/location plus modelID and
+// reqStream, allow-listed client headers are copied over, and authentication
+// is replaced with "Bearer <token>".
+//
+// It returns an error if the body conversion, URL construction, or request
+// creation fails. account must be non-nil; c may be nil, in which case no
+// client headers are copied.
+func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ token string,
+ modelID string,
+ reqStream bool,
+) (*http.Request, error) {
+ vertexBody, err := buildVertexAnthropicRequestBody(body)
+ if err != nil {
+ return nil, err
+ }
+ // Record the converted body for ops diagnostics.
+ // NOTE(review): called before the c != nil check below — confirm the helper tolerates a nil context.
+ setOpsUpstreamRequestBody(c, vertexBody)
+ fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+ if err != nil {
+ return nil, err
+ }
+
+ // Forward only allow-listed client headers. anthropic-version is never
+ // forwarded (presumably the Vertex body carries the version instead — confirm
+ // against buildVertexAnthropicRequestBody).
+ if c != nil && c.Request != nil {
+ for key, values := range c.Request.Header {
+ lowerKey := strings.ToLower(strings.TrimSpace(key))
+ if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
+ continue
+ }
+ wireKey := resolveWireCasing(key)
+ for _, v := range values {
+ addHeaderRaw(req.Header, wireKey, v)
+ }
+ }
+ }
+
+ // Drop any client-supplied credentials and cookies before installing the
+ // service-account bearer token.
+ req.Header.Del("authorization")
+ req.Header.Del("x-api-key")
+ req.Header.Del("x-goog-api-key")
+ req.Header.Del("cookie")
+ req.Header.Del("anthropic-version")
+ setHeaderRaw(req.Header, "authorization", "Bearer "+token)
+ setHeaderRaw(req.Header, "content-type", "application/json")
+
+ // Snapshot of the outgoing request for gateway debugging.
+ s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
+ "url": req.URL.String(),
+ "token_type": "service_account",
+ "model": modelID,
+ "stream": strconv.FormatBool(reqStream),
+ })
+
+ return req, nil
+}
+
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号,需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
@@ -6567,6 +6856,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false
}
+// sanitizeStreamError returns a client-visible description of a stream read
+// error that contains no network addresses.
+//
+// The default (*net.OpError).Error() concatenates the Source/Addr fields and so
+// leaks internal IPs/ports and the upstream server address (for example
+// "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection reset by peer").
+// Only a recognizable error category is returned here; the raw err is expected
+// to be logged by the caller.
+//
+// It returns "" for a nil err and "upstream connection error" when err matches
+// no known category.
+func sanitizeStreamError(err error) string {
+	if err == nil {
+		return ""
+	}
+	// Well-known sentinel errors first: they give the most precise category.
+	switch {
+	case errors.Is(err, io.ErrUnexpectedEOF):
+		return "unexpected EOF"
+	case errors.Is(err, io.EOF):
+		return "EOF"
+	case errors.Is(err, context.Canceled):
+		return "canceled"
+	case errors.Is(err, context.DeadlineExceeded):
+		return "deadline exceeded"
+	case errors.Is(err, syscall.ECONNRESET):
+		return "connection reset by peer"
+	case errors.Is(err, syscall.ECONNABORTED):
+		return "connection aborted"
+	case errors.Is(err, syscall.ETIMEDOUT):
+		return "connection timed out"
+	case errors.Is(err, syscall.EPIPE):
+		return "broken pipe"
+	case errors.Is(err, syscall.ECONNREFUSED):
+		return "connection refused"
+	}
+	// For *net.OpError keep only the operation name (read/write/dial/...);
+	// never its Source/Addr fields.
+	var opErr *net.OpError
+	if errors.As(err, &opErr) {
+		if opErr.Timeout() {
+			if opErr.Op != "" {
+				return opErr.Op + " timeout"
+			}
+			return "i/o timeout"
+		}
+		if opErr.Op != "" {
+			return opErr.Op + " network error"
+		}
+	}
+	// Timeouts that are not wrapped in *net.OpError (a bare
+	// os.ErrDeadlineExceeded, *url.Error, or any other net.Error) are still
+	// reported as timeouts instead of the generic fallback.
+	var timeoutErr net.Error
+	if errors.As(err, &timeoutErr) && timeoutErr.Timeout() {
+		return "i/o timeout"
+	}
+	return "upstream connection error"
+}
+
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
@@ -7004,14 +7336,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
lastDataAt := time.Now()
- // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
+ // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。
+ // 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":,"message":}}
+ // 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案,
+ // 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。
errorEventSent := false
- sendErrorEvent := func(reason string) {
+ sendErrorEvent := func(reason, message string) {
if errorEventSent {
return
}
errorEventSent = true
- _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ if message == "" {
+ message = reason
+ }
+ body, err := json.Marshal(map[string]any{
+ "type": "error",
+ "error": map[string]string{
+ "type": reason,
+ "message": message,
+ },
+ })
+ if err != nil {
+ // json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback
+ body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
+ }
+ _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
flusher.Flush()
}
@@ -7088,9 +7437,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
- // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
- overrideTarget := account.GetCacheTTLOverrideTarget()
+ // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
@@ -7171,10 +7520,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
- sendErrorEvent("response_too_large")
+ sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
- sendErrorEvent("stream_read_error")
+ // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY):
+ // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
+ // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。
+ // 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
+ // 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
+ // 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
+ disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
+ if !c.Writer.Written() {
+ logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
+ body, _ := json.Marshal(map[string]any{
+ "type": "error",
+ "error": map[string]string{
+ "type": "upstream_disconnected",
+ "message": disconnectMsg,
+ },
+ })
+ return nil, &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: body,
+ RetryableOnSameAccount: true,
+ }
+ }
+ sendErrorEvent("stream_read_error", disconnectMsg)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
@@ -7233,7 +7604,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
- sendErrorEvent("stream_timeout")
+ sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
@@ -7475,6 +7846,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
return true
}
+// resolveCacheTTLUsageOverrideTarget decides whether cache_creation usage for
+// account should be reclassified and, if so, to which TTL target.
+//
+// Precedence:
+//  1. A nil account never overrides.
+//  2. An enabled account-level override wins and supplies its own target.
+//  3. Otherwise, an Anthropic OAuth/SetupToken account is mapped to the 5m
+//     target while the global 1h cache-TTL injection setting is enabled.
+//
+// The boolean result reports whether an override applies; the string is empty
+// when it does not.
+func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
+	switch {
+	case account == nil:
+		return "", false
+	case account.IsCacheTTLOverrideEnabled():
+		return account.GetCacheTTLOverrideTarget(), true
+	case !account.IsAnthropicOAuthOrSetupToken(), s == nil, s.settingService == nil:
+		// Not eligible for the global fallback, or the setting cannot be read.
+		return "", false
+	}
+	if s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
+		return cacheTTLTarget5m, true
+	}
+	return "", false
+}
+
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -7511,9 +7895,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
- // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
- overrideTarget := account.GetCacheTTLOverrideTarget()
+ // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
@@ -8081,10 +8465,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
- // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ // Cache TTL Override: 确保计费时 token 分类与账号设置一致。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
cacheTTLOverridden := false
- if account.IsCacheTTLOverrideEnabled() {
- applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
+ applyCacheTTLOverride(&result.Usage, overrideTarget)
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go
index b1584827..ef09a882 100644
--- a/backend/internal/service/gateway_streaming_test.go
+++ b/backend/internal/service/gateway_streaming_test.go
@@ -4,9 +4,12 @@ package service
import (
"context"
+ "errors"
"io"
+ "net"
"net/http"
"net/http/httptest"
+ "syscall"
"testing"
"time"
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
body := rec.Body.String()
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
}
+
+// An upstream mid-stream read error (e.g. the unexpected EOF caused by an
+// HTTP/2 GOAWAY) that happens before any byte has been written to the client:
+// the gateway must return a *UpstreamFailoverError so the handler can fail
+// over / retry, instead of sending an error event straight to the client.
+func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := newMinimalGatewayService()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
+ }
+
+ result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+ require.Error(t, err)
+ require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
+
+ var failoverErr *UpstreamFailoverError
+ require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
+
+ // ResponseBody must use the standard Anthropic error format:
+ // 1) ExtractUpstreamErrorMessage must be able to read error.message (relied on by handleFailoverExhausted / ops logs)
+ // 2) error.type is set to upstream_disconnected
+ extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
+ require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
+ require.Contains(t, extractedMsg, "upstream stream disconnected")
+ require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
+ require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
+
+ // The client must not receive any stream_read_error event; the handler layer decides what to send based on the failover outcome.
+ require.NotContains(t, rec.Body.String(), "stream_read_error")
+}
+
+// A read error that occurs after upstream events have already been forwarded
+// (c.Writer has written bytes): SSE has no resume, so the gateway can only
+// pass a stream_read_error error event through to the client and must not
+// fail over.
+func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := newMinimalGatewayService()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ // The first Read returns a complete SSE event so the gateway writes bytes to the client; the second Read returns the EOF error.
+ // NOTE(review): streamReadCloser is defined elsewhere — presumably it yields payload first, then err; confirm.
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: &streamReadCloser{
+ payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
+ require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
+
+ // The error must not be wrapped as a failover error.
+ var failoverErr *UpstreamFailoverError
+ require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
+
+ // The client must receive an SSE error event in the standard Anthropic format with error.type=stream_read_error,
+ // and error.message must carry the concrete root cause (so SDKs can parse it and UIs can show a specific message).
+ body := rec.Body.String()
+ require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
+ require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
+ require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
+ require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
+}
+
+// The default (*net.OpError).Error() concatenates the Source/Addr fields and
+// leaks internal IPs/ports and the upstream server address. sanitizeStreamError
+// must strip that information so infrastructure topology is not returned to
+// clients through the failover ResponseBody or the SSE error frame.
+func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
+ src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+ require.NoError(t, err)
+ dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+ require.NoError(t, err)
+
+ raw := &net.OpError{
+ Op: "read",
+ Net: "tcp",
+ Source: src,
+ Addr: dst,
+ Err: syscall.ECONNRESET,
+ }
+
+ // Precondition: the raw Error() really does contain the fields that would leak (so the test cannot pass silently if Go's formatting changes).
+ require.Contains(t, raw.Error(), "10.0.0.1")
+ require.Contains(t, raw.Error(), "52.1.2.3")
+
+ got := sanitizeStreamError(raw)
+ require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
+ require.NotContains(t, got, "54321", "不得泄露源端口")
+ require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
+ require.NotContains(t, got, "443", "不得泄露上游端口")
+ require.Equal(t, "connection reset by peer", got)
+}
+
+// TestSanitizeStreamError_KnownErrors is a table-driven check of the category
+// string sanitizeStreamError returns for every recognized sentinel error, for a
+// *net.OpError without a recognized cause, for an unrecognized error, and for
+// nil.
+func TestSanitizeStreamError_KnownErrors(t *testing.T) {
+	cases := []struct {
+		name string
+		err  error
+		want string
+	}{
+		{"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
+		{"EOF", io.EOF, "EOF"},
+		{"context canceled", context.Canceled, "canceled"},
+		{"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
+		{"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
+		{"ECONNABORTED", syscall.ECONNABORTED, "connection aborted"},
+		{"EPIPE", syscall.EPIPE, "broken pipe"},
+		{"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
+		{"ECONNREFUSED", syscall.ECONNREFUSED, "connection refused"},
+		// A non-timeout OpError with no recognized cause keeps only its operation name.
+		{"OpError without known cause", &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("no route")}, "dial network error"},
+		{"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
+		{"nil 返回空串", nil, ""},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			require.Equal(t, tc.want, sanitizeStreamError(tc.err))
+		})
+	}
+}
+
+// The failover ResponseBody must carry the sanitized message so that no
+// internal address information is leaked to the client or written to ops logs.
+func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	svc := newMinimalGatewayService()
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+	// Fail fast if the fixture addresses cannot be built; a nil address would
+	// make the leak assertions below pass vacuously.
+	src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+	require.NoError(t, err)
+	dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+	require.NoError(t, err)
+	netErr := &net.OpError{
+		Op:     "read",
+		Net:    "tcp",
+		Source: src,
+		Addr:   dst,
+		Err:    syscall.ECONNRESET,
+	}
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
+		Body:       &streamReadCloser{err: netErr},
+	}
+
+	_, err = svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+	require.Error(t, err)
+
+	var failoverErr *UpstreamFailoverError
+	require.True(t, errors.As(err, &failoverErr))
+
+	body := string(failoverErr.ResponseBody)
+	require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
+	require.NotContains(t, body, "54321")
+	require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
+	require.NotContains(t, body, "443")
+	// The diagnosable root cause must still be present.
+	require.Contains(t, body, "connection reset by peer")
+	require.Contains(t, body, "upstream stream disconnected")
+}
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 7a24071b..ea0c0d7d 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
+ case AccountTypeServiceAccount:
+ // Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
+ // endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
+ return 999
default:
return 10
}
@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
+ case AccountTypeServiceAccount:
+ buildReq = func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ action := "generateContent"
+ if req.Stream {
+ action = "streamGenerateContent"
+ }
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
+ if err != nil {
+ return nil, "", err
+ }
+
+ restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }
+ requestIDHeader = "x-request-id"
+
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
+ case AccountTypeServiceAccount:
+ buildReq = func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
+ if err != nil {
+ return nil, "", err
+ }
+
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }
+ requestIDHeader = "x-request-id"
+
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go
index 7add3460..172b9411 100644
--- a/backend/internal/service/gemini_token_provider.go
+++ b/backend/internal/service/gemini_token_provider.go
@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
-// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
+// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
- if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
- return "", errors.New("not a gemini oauth account")
+ if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+ return "", errors.New("not a gemini oauth or service account")
+ }
+ if account.Type == AccountTypeServiceAccount {
+ return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
+func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+ return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account) // thin delegate: passes the provider's tokenCache and the account to the Vertex helper
+}
+
func GeminiTokenCacheKey(account *Account) string {
+ if account != nil && account.Type == AccountTypeServiceAccount {
+ if key, err := parseVertexServiceAccountKey(account); err == nil {
+ return vertexServiceAccountCacheKey(account, key)
+ }
+ }
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index e765d7e9..b256f1c7 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -53,6 +53,23 @@ const (
codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n"
)
+var openAIChatGPTInternalUnsupportedFields = []string{ // top-level request keys treated as unsupported by the ChatGPT internal endpoint; also folded into openAICodexOAuthUnsupportedFields
+ "user",
+ "metadata",
+ "prompt_cache_retention", // Responses API cache-TTL parameter; the ChatGPT internal Codex endpoint rejects it with "Unsupported parameter: prompt_cache_retention"
+ "safety_identifier",
+ "stream_options",
+}
+
+var openAICodexOAuthUnsupportedFields = append([]string{ // every key in this list is deleted from reqBody by applyCodexOAuthTransform
+ "max_output_tokens",
+ "max_completion_tokens",
+ "temperature",
+ "top_p",
+ "frequency_penalty",
+ "presence_penalty",
+}, openAIChatGPTInternalUnsupportedFields...) // the append base is a fresh literal, so the shared ChatGPT-internal list is only read, never written
+
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
@@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
- // Strip parameters unsupported by codex models via the Responses API.
- for _, key := range []string{
- "max_output_tokens",
- "max_completion_tokens",
- "temperature",
- "top_p",
- "frequency_penalty",
- "presence_penalty",
- // prompt_cache_retention is a newer Responses API parameter (cache TTL).
- // The ChatGPT internal Codex endpoint rejects it with
- // "Unsupported parameter: prompt_cache_retention". Defense-in-depth
- // for any OAuth path that reaches this transform — the Cursor
- // Responses-shape short-circuit in ForwardAsChatCompletions strips
- // it earlier too, but we keep this line so other OAuth callers are
- // equally protected.
- "prompt_cache_retention",
- } {
+ // Strip parameters unsupported by ChatGPT internal Codex endpoint.
+ for _, key := range openAICodexOAuthUnsupportedFields {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
@@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
- "function": map[string]any{
- "name": name,
- },
+ "name": name,
}
}
}
@@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool {
return false
}
choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
- if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) {
+ if choiceType == "" {
return false
}
+ modified := false
+ if choiceType == "function" {
+ name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
+ if name == "" {
+ if function, ok := choiceMap["function"].(map[string]any); ok {
+ name = strings.TrimSpace(firstNonEmptyString(function["name"]))
+ }
+ }
+ if name == "" {
+ reqBody["tool_choice"] = "auto"
+ return true
+ }
+ if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
+ choiceMap["name"] = name
+ modified = true
+ }
+ if _, ok := choiceMap["function"]; ok {
+ delete(choiceMap, "function")
+ modified = true
+ }
+ if !codexToolsContainFunctionName(reqBody["tools"], name) {
+ reqBody["tool_choice"] = "auto"
+ return true
+ }
+ return modified
+ }
+ if codexToolsContainType(reqBody["tools"], choiceType) {
+ return modified
+ }
reqBody["tool_choice"] = "auto"
return true
}
@@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool {
return false
}
+func codexToolsContainFunctionName(rawTools any, name string) bool { // reports whether rawTools lists a "function" tool whose trimmed name equals the trimmed name argument
+ tools, ok := rawTools.([]any)
+ if !ok || strings.TrimSpace(name) == "" { // rawTools is not a []any, or name is blank: nothing can match
+ return false
+ }
+ normalizedName := strings.TrimSpace(name)
+ for _, rawTool := range tools {
+ tool, ok := rawTool.(map[string]any)
+ if !ok {
+ continue // entry is not a map[string]any; skip it
+ }
+ if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
+ continue // only tools typed "function" are considered
+ }
+ toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
+ if toolName == "" { // no usable top-level name: fall back to the nested function.name shape
+ if function, ok := tool["function"].(map[string]any); ok {
+ toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
+ }
+ }
+ if toolName == normalizedName {
+ return true
+ }
+ }
+ return false
+}
+
func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
if len(input) == 0 {
return input, false
@@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
typ, _ := m["type"].(string)
+ // chatgpt.com codex backend (OAuth path) does not persist reasoning
+ // items because applyCodexOAuthTransform forces store=false. Any rs_*
+ // reference replayed in input is guaranteed to 404 upstream
+ // ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
+ if typ == "reasoning" {
+ continue
+ }
+
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 75f5c55c..87bb7162 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -1,6 +1,8 @@
package service
import (
+ "fmt"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
require.Equal(t, "custom", choice["type"])
}
+func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) { // nested tool_choice.function.name must be flattened to tool_choice.name
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{
+ "type": "function",
+ "function": map[string]any{"name": "shell"}, // legacy nested shape; names a tool declared in "tools" above
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ choice, ok := reqBody["tool_choice"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function", choice["type"])
+ require.Equal(t, "shell", choice["name"]) // name promoted to the top level of tool_choice
+ require.NotContains(t, choice, "function") // nested wrapper removed
+}
+
+func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) { // a tool_choice naming an undeclared function must fall back to "auto"
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{
+ "type": "function",
+ "function": map[string]any{"name": "missing"}, // no tool named "missing" is declared in "tools"
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ require.Equal(t, "auto", reqBody["tool_choice"])
+}
+
func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.4",
@@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
}
+func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) { // every key of openAIChatGPTInternalUnsupportedFields must be removed by the transform
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "user": "user_123",
+ "metadata": map[string]any{"trace_id": "abc"},
+ "prompt_cache_retention": "24h",
+ "safety_identifier": "sid",
+ "stream_options": map[string]any{"include_usage": true},
+ "input": []any{
+ map[string]any{"role": "user", "content": "hi"},
+ },
+ }
+
+ result := applyCodexOAuthTransform(reqBody, true, false)
+
+ require.True(t, result.Modified)
+ for _, field := range openAIChatGPTInternalUnsupportedFields { // iterate the production list so the test tracks future additions
+ require.NotContains(t, reqBody, field)
+ }
+}
+
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
})
}
}
+
+func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
+ // Reasoning items in input[] reference rs_* IDs that were emitted by
+ // chatgpt.com under store=false (forced by applyCodexOAuthTransform).
+ // They are never persisted upstream, so forwarding them produces a
+ // guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
+ // regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
+
+ build := func() []any { // returns a fresh input slice per call so subtests never share item maps
+ return []any{
+ map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
+ map[string]any{
+ "type": "reasoning",
+ "id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
+ "summary": []any{},
+ },
+ map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
+ map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
+ }
+ }
+
+ for _, preserve := range []bool{true, false} {
+ preserve := preserve // per-iteration copy captured by the subtest closure
+ t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
+ filtered := filterCodexInput(build(), preserve)
+
+ for _, raw := range filtered {
+ item, ok := raw.(map[string]any)
+ require.True(t, ok)
+ require.NotEqual(t, "reasoning", item["type"],
+ "reasoning items must be dropped from input on the OAuth path")
+ if id, ok := item["id"].(string); ok {
+ require.False(t, strings.HasPrefix(id, "rs_"),
+ "no item carrying an rs_* id should survive the filter")
+ }
+ }
+
+ // Sanity check: the non-reasoning items should still be present.
+ gotTypes := make(map[string]int) // item type → number of surviving items of that type
+ for _, raw := range filtered {
+ item, ok := raw.(map[string]any)
+ require.True(t, ok)
+ typ, ok := item["type"].(string)
+ require.True(t, ok)
+ gotTypes[typ]++
+ }
+ require.Equal(t, 1, gotTypes["message"])
+ require.Equal(t, 1, gotTypes["function_call"])
+ require.Equal(t, 1, gotTypes["function_call_output"])
+ require.Equal(t, 0, gotTypes["reasoning"])
+ })
+ }
+}
diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go
new file mode 100644
index 00000000..b52da614
--- /dev/null
+++ b/backend/internal/service/openai_fast_policy_test.go
@@ -0,0 +1,286 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIFastPolicyRepoStub struct { // in-memory settings repository stub: only GetValue and Set are implemented, every other method panics
+ values map[string]string
+}
+
+func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call") // not expected to be reached by these tests; fail loudly if it is
+}
+
+func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", ErrSettingNotFound // missing key is reported with the ErrSettingNotFound sentinel
+}
+
+func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
+ if s.values == nil { // lazily create the map so a zero-value stub can be written to
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ return nil
+}
+
+func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call") // not expected to be reached by these tests
+}
+
+func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call") // not expected to be reached by these tests
+}
+
+func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call") // not expected to be reached by these tests
+}
+
+func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call") // not expected to be reached by these tests
+}
+
+func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService { // builds a gateway service whose only populated field is settingService, backed by the in-memory stub
+ t.Helper()
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ if settings != nil { // nil settings leaves the repo empty, so GetValue on the stub returns ErrSettingNotFound for the policy key
+ raw, err := json.Marshal(settings)
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw) // stored as the JSON-encoded settings string
+ }
+ return &OpenAIGatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
+func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // The default policy applies to every model (its whitelist is empty) because codex's
+ // service_tier=fast is a user-level switch that is orthogonal to the model.
+ // gpt-5.5 + priority → filter
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-5.5-turbo → filter
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-4 + priority → filter (the default policy covers all models)
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-5.5 + flex → pass (tier doesn't match)
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
+ require.Equal(t, BetaPolicyActionPass, action)
+
+ // empty tier → pass
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
+ require.Equal(t, BetaPolicyActionPass, action)
+}
+
+func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) { // a matching block rule must return both the block action and its configured ErrorMessage
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "fast mode is not allowed",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) // model is on the whitelist and the tier matches the rule
+ require.Equal(t, BetaPolicyActionBlock, action)
+ require.Equal(t, "fast mode is not allowed", msg)
+}
+
+func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) { // an OAuth-scoped rule must match OAuth accounts and be skipped for API-key accounts
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierAny,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeOAuth, // the only rule; no ModelWhitelist or FallbackAction is set
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+
+ // OAuth account → rule matches
+ oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // API Key account → rule skipped → pass
+ apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionPass, action)
+}
+
+func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // gpt-5.5 fast → service_tier stripped
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // Client sending "fast" (alias for priority) also filtered
+ body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // gpt-4 priority → the default policy filters every model, so service_tier is removed
+ body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // No service_tier → no-op
+ body = []byte(`{"model":"gpt-5.5"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+}
+
+// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies that, after the whitelist
+// was extended, official OpenAI tiers sent explicitly by the client (auto/default/scale) reach the upstream
+// without being silently stripped. The default policy only targets priority, so these tiers land in the fall-through pass branch.
+func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ for _, tier := range []string{"auto", "default", "scale"} {
+ body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err, "tier %q should pass without error", tier)
+ require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
+ "tier %q should be preserved in body under default rule", tier)
+ }
+
+ // The evaluate layer should also return pass (the default rule's ServiceTier=priority does not match auto/default/scale)
+ for _, tier := range []string{"auto", "default", "scale"} {
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
+ require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
+ }
+}
+
+// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers verifies that once an admin explicitly configures
+// a ServiceTier=all + Action=filter rule, official tiers such as auto/default/scale are
+// stripped as well. This is intended: the first matching rule short-circuits, and "all" covers any recognized tier.
+func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierAny,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
+ body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`,
+ "tier %q should be stripped under ServiceTier=all + filter rule", tier)
+ }
+}
+
+// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies that a truly unknown tier is still stripped
+// (normalize returns nil → normalizeResponsesBodyServiceTier deletes the field;
+// applyOpenAIFastPolicyToBody is a plain no-op when normTier is empty, because the field can no longer exist
+// in a request that went through the upfront normalization. apply is called directly here to check it does not misbehave on unrecognized values).
+func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // The normalize stage strips unknown values
+ require.Nil(t, normalizeOpenAIServiceTier("xxx"))
+
+ // applyOpenAIFastPolicyToBody does not error on an unrecognized tier; the body passes through unchanged
+ // (not this function's responsibility — the upper-layer normalizeResponsesBodyServiceTier has already stripped it)
+ body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated)) // despite the test name, apply itself leaves the unknown tier in place
+}
+
+func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) { // a block rule must surface as *OpenAIFastBlockedError and leave the body untouched
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "fast mode is blocked for gpt-5.5",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.Error(t, err)
+ var blocked *OpenAIFastBlockedError
+ require.True(t, errors.As(err, &blocked)) // errors.As so a wrapped error would still satisfy the check
+ require.Contains(t, blocked.Message, "fast mode is blocked")
+ require.Equal(t, string(body), string(updated)) // body not mutated on block
+}
+
+func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) { // invalid action and invalid tier are rejected; a valid rule round-trips through Set/Get
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ svc := NewSettingService(repo, &config.Config{})
+
+ // Invalid action rejected
+ err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: "bogus", // the only invalid field in this rule
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.Error(t, err)
+
+ // Invalid service_tier rejected
+ err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: "turbo", // the only invalid field in this rule
+ Action: BetaPolicyActionPass,
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.Error(t, err)
+
+ // Valid settings persisted
+ err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.NoError(t, err)
+
+ got, err := svc.GetOpenAIFastPolicySettings(context.Background()) // read back through the same service and stub repo
+ require.NoError(t, err)
+ require.Len(t, got.Rules, 1)
+ require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
+}
diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go
new file mode 100644
index 00000000..3316a242
--- /dev/null
+++ b/backend/internal/service/openai_fast_policy_ws_test.go
@@ -0,0 +1,1018 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate ---
+
+func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) { // default policy on a response.create frame: service_tier removed, everything else kept
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
+ // Other fields preserved.
+ require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String())
+ require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String())
+ require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String()) // nested input content must survive the rewrite too
+}
+
+func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) { // the "fast" alias, in any casing or padding, must be filtered like "priority"
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Verbatim "fast" → normalized to "priority" → matches default rule → filter.
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // Mixed-case + whitespace variant should also normalize and filter.
+ frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`)
+}
+
+func TestWSResponseCreate_FlexPassThrough(t *testing.T) { // a tier the default rule does not target must be forwarded as-is
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Default policy targets priority only; flex is left untouched.
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy")
+}
+
+func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) { // on the WS path a block is reported via the second return value, not via err
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "ws fast blocked",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err) // err stays nil even though the frame is blocked
+ require.NotNil(t, blocked)
+ require.Equal(t, "ws fast blocked", blocked.Message)
+ // On block, payload returned unchanged so caller can inspect / log it.
+ require.Equal(t, string(frame), string(updated))
+}
+
+func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) { // a frame without service_tier must come back byte-for-byte identical
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation")
+}
+
+func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) { // frames of any other event type are outside the policy's reach
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"*"},
+ FallbackAction: BetaPolicyActionFilter, // filter on both the match and fallback paths, so an unchanged frame can only mean policy was never applied
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // response.cancel happens to carry a service_tier-shaped field — must not be touched.
+ frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated))
+}
+
+// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the
+// helper used to treat empty type as response.create, which risked stripping
+// fields from malformed / unknown client events. After the A1 fix only a
+// strict "response.create" match triggers policy.
+func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"*"},
+ FallbackAction: BetaPolicyActionFilter, // filter on both the match and fallback paths, so an unchanged frame can only mean policy was never applied
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Frame with no "type" field: must pass through completely unchanged
+ // even with a service_tier-shaped field present.
+ frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through")
+
+ // Explicit empty string also passes through.
+ frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated))
+}
+
+// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1
+// regression: the rendered Realtime error event must carry a non-empty
+// event_id (so clients can correlate the rejection) and a stable error.code
+// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error
+// JSON body emitted by writeOpenAIFastPolicyBlockedResponse.
+func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) {
+ bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"}) // NOTE(review): local name matches the stdlib "bytes" package; harmless here because this file does not import it
+ require.NotNil(t, bytes)
+
+ require.Equal(t, "error", gjson.GetBytes(bytes, "type").String())
+ require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String())
+ require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String())
+ require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String()) // the message is copied verbatim from the blocked error
+
+ eventID := gjson.GetBytes(bytes, "event_id").String()
+ require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs")
+ require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_ Realtime convention; got %q", eventID)
+
+ // Sanity check: two consecutive events get distinct IDs.
+ other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"})
+ otherID := gjson.GetBytes(other, "event_id").String()
+ require.NotEqual(t, eventID, otherID, "event_id must be random per-event")
+}
+
+// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe checks that the helper returns
+// nil when given a nil error (defensive guard for callers that always invoke it).
+func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) {
+ require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil))
+}
+
+// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback ---
+
+// fakePassthroughFrameConn replays a fixed sequence of client frames into the
+// policy-enforcing wrapper, then returns errOpenAIWSConnClosed. It records
+// every WriteFrame payload for write-side assertions (none expected in the D5
+// test, since the wrapper only filters reads).
+type fakePassthroughFrameConn struct {
+ reads [][]byte // scripted frames handed out by ReadFrame, in order
+ idx int // index of the next entry in reads to return
+ writes [][]byte // copies of every payload passed to WriteFrame
+ closeOnce bool // set to true by Close; never reset
+}
+
+func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if f.idx >= len(f.reads) { // script exhausted
+ return coderws.MessageText, nil, errOpenAIWSConnClosed
+ }
+ payload := f.reads[f.idx]
+ f.idx++ // advance so the next call yields the following scripted frame
+ return coderws.MessageText, payload, nil // every scripted frame is reported as a text message; ctx is ignored
+}
+
+func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+ cp := append([]byte(nil), payload...) // defensive copy: the recorded bytes do not alias the caller's slice
+ f.writes = append(f.writes, cp) // msgType and ctx are not recorded
+ return nil // writes never fail in this fake
+}
+
+func (f *fakePassthroughFrameConn) Close() error {
+ f.closeOnce = true // records that Close was called; repeated calls are not distinguished
+ return nil
+}
+
+// gpt55WhitelistFastPolicy returns a policy that always carries a model
+// whitelist, used to verify the capturedSessionModel fallback semantics (with
+// the default policy's empty whitelist the fallback path cannot be observed).
+func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
+ return &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter, // applied when the model matches the whitelist
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
+ FallbackAction: BetaPolicyActionPass, // applied when the model misses the whitelist
+ }},
+ }
+}
+
+// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is
+// the D5 regression: in passthrough mode a follow-up response.create frame
+// without a "model" field must still hit the policy via the session-level
+// model captured from the first frame. Without the fallback an empty model
+// would miss a model whitelist and silently leak service_tier=priority
+// through to the upstream.
+func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
+ // Deliberately use the whitelist-bearing policy here so the
+ // capturedSessionModel fallback is observable (the default policy's whitelist
+ // is empty, so the result is the same either way and cannot cover this regression).
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Simulate the passthrough adapter capturing model from the first frame.
+ firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
+ require.Equal(t, "gpt-5.5", capturedSessionModel)
+
+ // Follow-up frame deliberately omits "model" — Realtime allows this.
+ followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
+
+ inner := &fakePassthroughFrameConn{
+ reads: [][]byte{followupFrame}, // only the follow-up is replayed; the first frame was consumed above
+ }
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText { // non-text frames bypass the policy untouched
+ return payload, nil, nil
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" { // the fallback under test: reuse the session-level model
+ model = capturedSessionModel
+ }
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ // Read the follow-up frame through the wrapper. The policy MUST still
+ // trigger filter (gpt-5.5 + priority → filter), so the service_tier
+ // field is gone by the time the relay sees it.
+ _, payload, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.NotContains(t, string(payload), `"service_tier"`,
+ "D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5")
+ require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String()) // the rest of the frame survives the filter
+}
+
+// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the
+// inverse: when the wrapper has NO capturedSessionModel fallback (model is
+// empty per-frame and no fallback is wired up), the policy fails to match
+// the model whitelist and the frame leaks through unchanged. This documents
+// exactly the leak the D5 fix prevents.
+func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) {
+ // Use the same whitelist-bearing policy so the leak is observable.
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`) // no "model" field
+ inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}}
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ // NO fallback — emulate the pre-fix behavior.
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ _, payload, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ // Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept.
+ require.Contains(t, string(payload), `"service_tier"`,
+ "sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing")
+}
+
+// --- Ingress end-to-end test (filter path) ---
+
+// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the
+// real ProxyResponsesWebSocketFromClient ingress session pipeline against a
+// captureConn upstream and asserts that a client frame with service_tier=fast
+// is normalized + filtered out before being written upstream. This is the
+// integration flavour of TestWSResponseCreate_FilterStripsServiceTier.
+func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{ // stand-in upstream; presumably replays events to the proxy — confirm against its definition
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings())
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON) // the default policy is what the ingress path will load
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ settingService: NewSettingService(repo, cfg),
+ }
+
+ account := &Account{
+ ID: 901,
+ Name: "openai-ws-filter",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1) // buffered: every handler path sends exactly once and must never block
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() { _ = conn.CloseNow() }() // best-effort teardown; the error is deliberately ignored
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone() // set the header on a copy, not on the incoming request's map
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ _, firstMessage, readErr := conn.Read(readCtx)
+ cancel() // release the timeout context immediately instead of deferring
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) // rewrites the http:// test URL to ws://
+ cancelDial()
+ require.NoError(t, err)
+ defer func() { _ = clientConn.CloseNow() }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`)))
+ cancelWrite()
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, event, readErr := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, readErr)
+ require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) // the queued upstream event reached the client
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second): // bound the wait so a hung proxy fails the test instead of stalling it
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Len(t, captureConn.writes, 1, "上游应只收到一条 response.create")
+ upstream := captureConn.writes[0] // the single recorded upstream frame, inspected by JSON key below
+ _, hasServiceTier := upstream["service_tier"]
+ require.False(t, hasServiceTier, "上游收到的 response.create 不应包含 service_tier 字段(已被 fast policy filter 删除)")
+ require.Equal(t, "response.create", upstream["type"])
+ require.Equal(t, "gpt-5.5", upstream["model"])
+}
+
+// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the
+// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It
+// asserts that with a custom block rule, the client receives a Realtime-style
+// error event AND the upstream FrameConn never receives the offending frame.
+func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ // No events queued; the upstream should never get written to anyway.
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ blockSettings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "ws priority blocked for testing", // asserted verbatim on the client-side error event below
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ raw, err := json.Marshal(blockSettings)
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ settingService: NewSettingService(repo, cfg),
+ }
+
+ account := &Account{
+ ID: 902,
+ Name: "openai-ws-block",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1) // buffered: every handler path sends exactly once and must never block
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() { _ = conn.CloseNow() }() // best-effort teardown; the error is deliberately ignored
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone() // set the header on a copy, not on the incoming request's map
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ _, firstMessage, readErr := conn.Read(readCtx)
+ cancel() // release the timeout context immediately instead of deferring
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ // Mirror the production handler (openai_gateway_handler.go:1325-1328):
+ // when the proxy returns an OpenAIWSClientCloseError, surface its
+ // status code to the client via a graceful close handshake. Without
+ // this the deferred CloseNow() above would tear down the TCP
+ // connection without sending a close frame, and the C3 timing
+ // assertion (next read returns CloseStatus=1008) would see EOF
+ // instead.
+ var closeErr *OpenAIWSClientCloseError
+ if errors.As(proxyErr, &closeErr) {
+ _ = conn.Close(closeErr.StatusCode(), closeErr.Reason()) // close error ignored: the test asserts on the client side
+ }
+ serverErrCh <- proxyErr
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) // rewrites the http:// test URL to ws://
+ cancelDial()
+ require.NoError(t, err)
+ defer func() { _ = clientConn.CloseNow() }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`)))
+ cancelWrite()
+
+ // C3 timing assertion: the FIRST frame the client reads must be the
+ // error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is
+ // synchronous (writeFrame Flushes the bufio writer at write.go:307-311
+ // before returning) and the close handshake re-acquires the same
+ // writeFrameMu, so this ordering is enforced by the library itself; this
+ // assertion guards against future refactors that might break it.
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, event, readErr := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, readErr, "first read must succeed and return the error event before any close frame")
+ require.Equal(t, "error", gjson.GetBytes(event, "type").String())
+ require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String())
+ // B1 regression: event_id + error.code must be populated.
+ require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String())
+ require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate")
+ require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing")
+
+ // Next read must surface the close frame (as a CloseError). This
+ // asserts the [error event, close] ordering — i.e. the close did NOT
+ // race ahead of the data frame.
+ readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, secondReadErr := clientConn.Read(readCtx2)
+ cancelRead2()
+ require.Error(t, secondReadErr, "after the error event the connection must surface a close")
+ require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr),
+ "close status must be PolicyViolation; got %v", secondReadErr)
+
+ select {
+ case serverErr := <-serverErrCh:
+ // Server returns an OpenAIWSClientCloseError — handler closes the WS;
+ // here we just assert it surfaced as the typed close error.
+ require.Error(t, serverErr)
+ var closeErr *OpenAIWSClientCloseError
+ require.True(t, errors.As(serverErr, &closeErr), "block 应返回 OpenAIWSClientCloseError,得到 %T: %v", serverErr, serverErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ case <-time.After(5 * time.Second): // bound the wait so a hung proxy fails the test instead of stalling it
+ t.Fatal("等待 ingress 关闭超时")
+ }
+
+ // Critical: the offending frame must NEVER reach the upstream.
+ // captureDialer.DialCount may legitimately be 0 or 1 depending on whether
+ // the lease was acquired before policy fired; either way, no writes.
+ require.Empty(t, captureConn.writes, "block 命中后上游不应收到 response.create")
+}
+
+// --- HTTP-side gap-filling tests (already covered by existing tests but
+// requested to be split out explicitly) ---
+
+// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that
+// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule
+// action is "block", and that the body is left untouched. The caller (chat
+// completions / messages handlers) inspects this typed error and skips the
+// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and
+// openai_gateway_messages.go:149.
+func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "priority blocked", // expected verbatim in blocked.Message below
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`) // whitelisted model + priority tier → rule matches
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.Error(t, err)
+ var blocked *OpenAIFastBlockedError
+ require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request")
+ require.Equal(t, "priority blocked", blocked.Message)
+ require.Equal(t, string(body), string(updated), "block must not mutate body")
+}
+
+// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies
+// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode
+// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60)
+// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has
+// no service_tier. We exercise the same internal pipeline (Anthropic→Responses
+// + BetaFastMode + policy) without spinning up a real upstream HTTP server.
+func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50).
+ anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
+ var anthropicReq apicompat.AnthropicRequest
+ require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq))
+ responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
+ require.NoError(t, err)
+
+ // Step 2: BetaFastMode header → service_tier="priority" (mirrors line 58-61).
+ headers := http.Header{}
+ headers.Set("anthropic-beta", claude.BetaFastMode)
+ require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode))
+ responsesReq.ServiceTier = "priority" // assigned by hand: the test replays the injection instead of calling the handler
+ responsesReq.Model = "gpt-5.5" // likewise set directly on the converted request
+
+ // Step 3: marshal & apply fast policy (mirrors line 78 + 149).
+ responsesBody, err := json.Marshal(responsesReq)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置:beta 翻译应当注入 priority")
+
+ upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody)
+ require.NoError(t, policyErr)
+
+ // Step 4: assert that policy filtered the field before the upstream HTTP request.
+ require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier")
+}
+
+// --- Fix1: passthrough capturedSessionModel must follow session.update ---
+
+// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the
+// fix1 bypass: client opens with a whitelist-miss model (gpt-4o → pass under
+// gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then sends
+// response.create without "model". Without the session.update sniffing the
+// follow-up frame would fall back to the stale gpt-4o capture and pass — the
+// fix updates capturedSessionModel from session.update events so the fallback
+// now resolves to gpt-5.5 and the policy filters service_tier.
+func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Frame 1: response.create with whitelist-miss model — under the
+ // whitelist rule's fallback=pass, service_tier stays.
+ first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`)
+ // Frame 2: session.update rotates the session model to gpt-5.5.
+ rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`)
+ // Frame 3: response.create WITHOUT model — must inherit gpt-5.5.
+ followup := []byte(`{"type":"response.create","service_tier":"priority"}`)
+
+ inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}}
+
+ // Replicate the production wiring in openai_ws_v2_passthrough_adapter.go
+ // so capturedSessionModel state is shared across frames.
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first)
+ require.Equal(t, "gpt-4o", capturedSessionModel)
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText { // non-text frames bypass the policy untouched
+ return payload, nil, nil
+ }
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated // rotate the captured model; an empty result keeps the old value
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" { // the frame carries no model: use the (possibly rotated) session model
+ model = capturedSessionModel
+ }
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ // Frame 1: gpt-4o miss whitelist → pass (service_tier preserved).
+ _, payload1, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier")
+
+ // Frame 2: session.update — not response.create, untouched, but its
+ // side effect updates capturedSessionModel to gpt-5.5.
+ _, payload2, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim")
+ require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel")
+
+ // Frame 3: empty model + new captured gpt-5.5 → matches whitelist → filter.
+ _, payload3, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.NotContains(t, string(payload3), `"service_tier"`,
+ "fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter")
+}
+
+// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative
+// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only
+// client→upstream session.update frames rotate the captured model;
+// server→client events (session.created) and unrelated frames must not.
+func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) {
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // session.created is a server→client event in the OpenAI Realtime
+ // protocol — clients never send it, so this filter (which only runs on
+ // the client→upstream direction) must ignore it even if it appears.
+ created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created))
+
+ // Frames of any other type must NOT trigger rotation, even with a session.model field.
+ notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession))
+
+ // Missing session.model returns empty — caller keeps the old captured value.
+ noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel))
+}
+
+// --- Fix2: native /responses normalize "fast" → "priority" on pass ---
+
+// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2
+// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody
+// returned the body unchanged so a raw "fast" alias would leak to the
+// upstream OpenAI API (which does not accept "fast"). The fix normalizes
+// "fast" → "priority" on pass too.
+func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) {
+ // Use a policy that deliberately misses gpt-4 so the action is pass.
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // gpt-4 + "fast" → fallback pass. Body must be rewritten to "priority".
+ body := []byte(`{"model":"gpt-4","service_tier":"fast"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(),
+ "fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug")
+
+ // Already-canonical "priority" on pass: zero mutation (byte-equal).
+ body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+
+ // Mixed-case, whitespace-padded alias (" Fast ") → normalized.
+ body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())
+
+ // Unrecognized tier → still no-op (not normalized, since normTier == "").
+ body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+}
+
+// --- Fix3: passthrough billing must reflect post-filter service_tier ---
+
+// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The
+// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts
+// requestServiceTier from firstClientMessage AFTER applyOpenAIFastPolicy
+// has rewritten it, so a filter hit causes billing to report nil (default
+// tier) instead of the user-requested "priority". This test pins the
+// contract those two helpers must uphold for the adapter's billing path.
+func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+
+ // Pre-filter sanity: extracting from the raw frame would (incorrectly,
+ // pre-fix) report "priority" — this is the very thing the adapter
+ // must NOT do anymore.
+ pre := extractOpenAIServiceTierFromBody(raw)
+ require.NotNil(t, pre) // guard before the dereference on the next assertion
+ require.Equal(t, "priority", *pre,
+ "sanity: raw first frame carries priority that pre-fix billing would have reported")
+
+ // Apply policy filter (default rule: gpt-5.5 + priority → filter).
+ filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(filtered), `"service_tier"`)
+
+ // Post-filter: extracting from the rewritten frame returns nil. This
+ // is the value the adapter now passes to OpenAIForwardResult.ServiceTier,
+ // so billing records "default" instead of "priority".
+ post := extractOpenAIServiceTierFromBody(filtered)
+ require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority")
+
+ // And the byte-level invariant the adapter relies on: filtering an
+ // already-filtered frame is a no-op (idempotent), so re-running the
+ // policy doesn't accidentally re-introduce the field.
+ again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered)
+ require.NoError(t, err)
+ require.Nil(t, blocked2)
+ require.Equal(t, string(filtered), string(again),
+ "policy is idempotent: filtering an already-filtered frame leaves bytes unchanged")
+}
+
+// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap
+// flagged in the review: when a client sends service_tier as a non-string
+// (number, null, object, etc.) the policy must NOT panic and must NOT
+// pretend the field was filtered. Behavior: skip policy entirely (treat as
+// "no usable tier"), forward body unchanged. This mirrors the HTTP entry's
+// type-assertion `reqBody["service_tier"].(string); ok` guard.
+func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Every case carries a non-string service_tier. E.g. for the number, gjson
+ // .String() coerces to "1" — not a recognized alias, normalize returns "" → no-op.
+ cases := [][]byte{
+ []byte(`{"model":"gpt-5.5","service_tier":1}`),
+ []byte(`{"model":"gpt-5.5","service_tier":null}`),
+ []byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`),
+ []byte(`{"model":"gpt-5.5","service_tier":["priority"]}`),
+ []byte(`{"model":"gpt-5.5","service_tier":true}`),
+ }
+ for _, body := range cases {
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err, "non-string service_tier must not error: %s", string(body))
+ require.Equal(t, string(body), string(updated),
+ "non-string service_tier must pass through unchanged: %s", string(body))
+ }
+
+ // Same guard for the WS response.create entry. NOTE(review): these frames carry no "type", so they may take the empty-type passthrough and never reach the tier check — confirm, and consider adding "type":"response.create".
+ for _, body := range cases {
+ frame := body
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame))
+ require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame))
+ require.Equal(t, string(frame), string(updated),
+ "non-string service_tier ws frame must pass through unchanged: %s", string(frame))
+ }
+}
+
+// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the
+// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS
+// allows the client to ship a different service_tier on each response.create
+// frame (per-response field, see codex-rs/core/src/client.rs
+// build_responses_request which re-fills the field on every request). Before
+// the fix the adapter only captured service_tier from firstClientMessage so
+// turn 2/3 billing was wrong. After the fix the filter closure refreshes an
+// atomic.Pointer[string] on every successful response.create frame.
+//
+// This test pins the four legs of the semantic contract:
+// - turn 1: service_tier=priority hits the default whitelist filter, so
+// after filter the upstream sees no tier → billing is nil.
+// - turn 2: service_tier=flex passes (default rule targets priority only),
+// billing should now reflect "flex".
+// - turn 3: response.create without any service_tier — the upstream will
+// treat it as default; we choose to mirror that and overwrite billing
+// to nil rather than carry over "flex" from turn 2.
+// - non-response.create frame (response.cancel here) carrying a stray
+// service_tier-shaped field must NOT clobber the billing pointer.
+//
+// Execution order below is turn 1, turn 2, the response.cancel frame, then
+// turn 3: the cancel leg runs while the pointer still holds "flex", so a
+// clobber would be observable.
+func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go
+ // proxyResponsesWebSocketV2Passthrough) so this test fails if the
+ // production code drops the per-frame Store.
+ var requestServiceTierPtr atomic.Pointer[string]
+ capturedSessionModel := ""
+ filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ // Only text frames are inspected; anything else is forwarded as-is.
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ // Remember the most recent non-empty model reported for a session
+ // frame so later frames without their own model can fall back to it.
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ // Refresh billing only for a response.create frame that was neither
+ // errored nor blocked; the tier is read from the post-policy bytes.
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ }
+
+ // First-frame initialization mirrors the adapter: extract from the
+ // post-filter payload so a filter-on-first-frame zeroes billing too.
+ // Note that turn 1 calls the policy helper directly instead of going
+ // through the filter closure above.
+ firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame)
+ require.NoError(t, firstErr)
+ require.Nil(t, firstBlocked)
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut))
+ capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
+ require.Nil(t, requestServiceTierPtr.Load(),
+ "turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier")
+
+ // Turn 2: client switches to flex, should pass and update billing.
+ turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
+ out2, blocked2, err2 := filter(coderws.MessageText, turn2)
+ require.NoError(t, err2)
+ require.Nil(t, blocked2)
+ require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched")
+ tier2 := requestServiceTierPtr.Load()
+ require.NotNil(t, tier2, "turn 2: billing must update to reflect flex")
+ require.Equal(t, "flex", *tier2)
+
+ // A non-response.create frame with a stray service_tier-shaped field
+ // must NOT overwrite the billing pointer (those frames don't carry
+ // per-response service_tier in the Realtime spec).
+ cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
+ _, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame)
+ require.NoError(t, errCancel)
+ require.Nil(t, blockedCancel)
+ tierAfterCancel := requestServiceTierPtr.Load()
+ require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil")
+ require.Equal(t, "flex", *tierAfterCancel,
+ "non-response.create frames must not update billing tier even if they carry a service_tier-shaped field")
+
+ // Turn 3: response.create without any service_tier. We deliberately
+ // overwrite billing back to nil so it tracks what the upstream actually
+ // sees on this turn (default tier).
+ turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`)
+ out3, blocked3, err3 := filter(coderws.MessageText, turn3)
+ require.NoError(t, err3)
+ require.Nil(t, blocked3)
+ require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate")
+ require.Nil(t, requestServiceTierPtr.Load(),
+ "turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
+}
+
+// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
+// "block keeps previous" semantic: when policy returns block on a
+// response.create frame, that frame is never sent upstream, so billing tier
+// must keep the previous turn's value rather than getting silently zeroed.
+func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) {
+ // Single rule: block service_tier=priority for the whitelisted model
+ // gpt-5.5 on every account scope; non-whitelisted models fall back to pass.
+ blockSettings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "blocked",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, blockSettings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ var requestServiceTierPtr atomic.Pointer[string]
+ flexValue := "flex"
+ requestServiceTierPtr.Store(&flexValue) // simulate prior turn billed as flex
+
+ filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ // Only text frames are inspected; anything else is forwarded as-is.
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload)
+ // The Store is skipped whenever the policy errored or blocked, which
+ // is exactly the branch this test exercises.
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ }
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ _, blocked, err := filter(coderws.MessageText, frame)
+ require.NoError(t, err)
+ require.NotNil(t, blocked, "policy must block this frame")
+
+ // The pointer must still hold the value seeded before the blocked frame.
+ tier := requestServiceTierPtr.Load()
+ require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil")
+ require.Equal(t, "flex", *tier,
+ "blocked frame is never sent upstream; billing must retain the previous turn's tier")
+}
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index 663066a3..5822ae4c 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
+ // 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
+ }
+ return nil, policyErr
+ }
+ responsesBody = updatedBody
+
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
index a00fb71c..6846e03a 100644
--- a/backend/internal/service/openai_gateway_chat_completions_test.go
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
+ // OpenAI 官方合法 tier 应被透传保留。
+ req.ServiceTier = "auto"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "auto", req.ServiceTier)
+
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "default", req.ServiceTier)
+
+ req.ServiceTier = "scale"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "scale", req.ServiceTier)
+
+ // 真未知值仍被剥离。
+ req.ServiceTier = "turbo"
+ normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+ // OpenAI 官方 tier 直接保留在 body 中(透传上游)。
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
+ require.NoError(t, err)
+ require.Equal(t, "auto", tier)
+ require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
+
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
+ require.Equal(t, "default", tier)
+ require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
+ require.NoError(t, err)
+ require.Equal(t, "scale", tier)
+ require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
+
+ // 真未知值才会被删除。
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
+ require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}
diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go
index 2a0a72eb..4e0ebb2e 100644
--- a/backend/internal/service/openai_gateway_messages.go
+++ b/backend/internal/service/openai_gateway_messages.go
@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
+ // 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
+ // Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
+ // on the body-level service_tier field (priority/flex).
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
+ }
+ return nil, policyErr
+ }
+ responsesBody = updatedBody
+
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index 9665c4c8..47ff4e3b 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
+ nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
- t.Run("default ignored", func(t *testing.T) {
- require.Nil(t, normalizeOpenAIServiceTier("default"))
+ t.Run("openai official tiers preserved", func(t *testing.T) {
+ // OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
+ // 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
+ // 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
+ for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
+ got := normalizeOpenAIServiceTier(tier)
+ require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
+ require.Equal(t, tier, *got)
+ }
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
+ require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
+ require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
+ require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
+ require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
- require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
+ require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
+ require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
+ require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
+ require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 379ebe0b..ed69730c 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
+ settingService *SettingService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
+ settingService *SettingService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
+ settingService: settingService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
+// explicitOpenAISessionID returns the first non-blank explicit session signal
+// supplied by the client, checked in this order: the "session_id" header, the
+// "conversation_id" header, then the body's top-level "prompt_cache_key".
+// It returns "" when c is nil or when no signal is present.
+func explicitOpenAISessionID(c *gin.Context, body []byte) string {
+	if c == nil {
+		return ""
+	}
+
+	// Headers take precedence over the body field.
+	for _, header := range []string{"session_id", "conversation_id"} {
+		if id := strings.TrimSpace(c.GetHeader(header)); id != "" {
+			return id
+		}
+	}
+	if len(body) == 0 {
+		return ""
+	}
+	return strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
+}
+
+// GenerateExplicitSessionHash derives a sticky-session hash exclusively from
+// explicit client session signals (see explicitOpenAISessionID). Unlike
+// GenerateSessionHash it never falls back to a content-derived seed, which is
+// why stateless endpoints such as /v1/images use it. It returns "" when the
+// client sent no explicit signal; otherwise the legacy hash is attached to
+// the gin context and the current hash is returned.
+func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
+	explicitID := explicitOpenAISessionID(c, body)
+	if explicitID != "" {
+		current, legacy := deriveOpenAISessionHashes(explicitID)
+		attachOpenAILegacySessionHashToGin(c, legacy)
+		return current
+	}
+	return ""
+}
+
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
@@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
- sessionID := strings.TrimSpace(c.GetHeader("session_id"))
- if sessionID == "" {
- sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
- }
- if sessionID == "" && len(body) > 0 {
- sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
- }
+ sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
@@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
+ // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
+ // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
+ // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
+ // fast 时在此生效。
+ //
+ // 注意:
+ // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
+ // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
+ // chat-completions / messages 入口保持一致,避免不同入口因为模型
+ // 维度不同而出现 whitelist 命中差异。
+ // 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
+ // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
+ // completions 入口由 normalizeResponsesBodyServiceTier 完成同一
+ // 行为,这里手工实现等效逻辑。
+ if rawTier, ok := reqBody["service_tier"].(string); ok {
+ if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
+ }
+ blocked := &OpenAIFastBlockedError{Message: msg}
+ writeOpenAIFastPolicyBlockedResponse(c, blocked)
+ return nil, blocked
+ case BetaPolicyActionFilter:
+ delete(reqBody, "service_tier")
+ bodyModified = true
+ disablePatch()
+ default:
+ // pass:若客户端传的是别名 "fast",归一化为 "priority"
+ // 后写回 body,确保上游收到的是其能识别的规范值。
+ if normTier != rawTier {
+ reqBody["service_tier"] = normTier
+ bodyModified = true
+ markPatchSet("service_tier", normTier)
+ }
+ }
+ }
+ }
+
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
body = sanitizedBody
}
+ // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
+ // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
+ // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
+ // 这样可以与 chat-completions / messages / native /responses 入口的
+ // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
+ // model 字段时退回 reqModel。
+ policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
+ if policyModel == "" {
+ policyModel = reqModel
+ }
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeOpenAIFastPolicyBlockedResponse(c, blocked)
+ }
+ return nil, policyErr
+ }
+ body = updatedBody
+
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@@ -4841,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
}
normalized := []byte(`{}`)
- for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
+ // Keep the current Codex /compact schema while still dropping request-scoped
+ // fields such as prompt_cache_key, store, and stream.
+ for _, field := range []string{
+ "model",
+ "input",
+ "instructions",
+ "tools",
+ "parallel_tool_calls",
+ "reasoning",
+ "text",
+ "previous_response_id",
+ } {
value := gjson.GetBytes(body, field)
if !value.Exists() {
continue
@@ -5454,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
}
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
-// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false
+// 1) 删除 ChatGPT internal API 不支持的顶层 Responses 参数
+// 2) store=false 3) 非 compact 保持 stream=true;compact 强制 stream=false
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
@@ -5463,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo
normalized := body
changed := false
+ for _, field := range openAIChatGPTInternalUnsupportedFields {
+ if value := gjson.GetBytes(normalized, field); !value.Exists() {
+ continue
+ }
+ next, err := sjson.DeleteBytes(normalized, field)
+ if err != nil {
+ return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
+ }
+ normalized = next
+ changed = true
+ }
+
if compact {
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
next, err := sjson.DeleteBytes(normalized, "store")
@@ -5567,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
if value == "fast" {
value = "priority"
}
+ // 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
+ // 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
+ // 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
+ // 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
switch value {
- case "priority", "flex":
+ case "priority", "flex", "auto", "default", "scale":
return &value
default:
return nil
}
}
+// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
+// policy (action=block). Mirrors BetaBlockedError on the Claude side.
+type OpenAIFastBlockedError struct {
+ // Message is the client-facing reason; it is also the error string.
+ Message string
+}
+
+// Error implements the error interface by returning Message verbatim.
+func (e *OpenAIFastBlockedError) Error() string { return e.Message }
+
+// evaluateOpenAIFastPolicy returns the action and error message that should be
+// applied for a request with the given account/model/service_tier. When the
+// policy service is unavailable or no rule matches, it returns
+// (BetaPolicyActionPass, "") so callers can short-circuit safely.
+//
+// Matching rules:
+// - Scope filters by account type (all / oauth / apikey / bedrock)
+// - ServiceTier must be empty (= any), "all", or equal the normalized tier
+// - ModelWhitelist narrows the rule to specific models; FallbackAction
+// handles the non-matching case (default: pass)
+//
+// Difference from the Claude BetaPolicy (first-match short-circuit is kept):
+// - BetaPolicy works on the set of tokens in the anthropic-beta header;
+// different rules may target different tokens, so filter results must be
+// accumulated into a set, while block is first-match.
+// - The OpenAI fast policy operates on the single service_tier field:
+// filter simply deletes the field, so there is nothing to accumulate. A
+// request carries exactly one service_tier, so rules are naturally
+// mutually exclusive along the tier dimension; when several rules under
+// the same (scope, tier) have overlapping model whitelists, the admin
+// expresses intent through rule order. Hence first-match is used rather
+// than BetaPolicy's "block overrides filter overrides pass" semantics.
+func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
+ // No service wired in (nil receiver or nil settingService): nothing to enforce.
+ if s == nil || s.settingService == nil {
+ return BetaPolicyActionPass, ""
+ }
+ tier := strings.ToLower(strings.TrimSpace(serviceTier))
+ if tier == "" {
+ return BetaPolicyActionPass, ""
+ }
+ // Prefer a settings snapshot carried on the context; only fall back to
+ // the setting service when none is attached.
+ settings := openAIFastPolicySettingsFromContext(ctx)
+ if settings == nil {
+ fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
+ // Fail open: a fetch error or missing settings results in pass.
+ if err != nil || fetched == nil {
+ return BetaPolicyActionPass, ""
+ }
+ settings = fetched
+ }
+ return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
+}
+
+// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
+// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
+// the settingService on every frame. See WSSession entry and
+// openAIFastPolicySettingsFromContext for the caching glue.
+//
+// It walks settings.Rules in order and lets the first rule whose scope and
+// tier match decide the outcome; nil settings or no matching rule yields
+// (BetaPolicyActionPass, ""). The tier argument is compared as-is against
+// the lower-cased, trimmed rule tier, so callers should pass it lower-cased.
+func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
+ if settings == nil {
+ return BetaPolicyActionPass, ""
+ }
+ // A nil account is treated as neither OAuth nor Bedrock.
+ isOAuth := account != nil && account.IsOAuth()
+ isBedrock := account != nil && account.IsBedrock()
+ for _, rule := range settings.Rules {
+ if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
+ continue
+ }
+ // An empty rule tier or OpenAIFastTierAny matches every tier.
+ ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+ if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
+ continue
+ }
+ // Re-shape the rule as a BetaPolicyRule so the shared resolver can be
+ // reused; presumably resolveRuleAction applies ModelWhitelist and the
+ // fallback fields — confirm against its definition.
+ eff := BetaPolicyRule{
+ Action: rule.Action,
+ ErrorMessage: rule.ErrorMessage,
+ ModelWhitelist: rule.ModelWhitelist,
+ FallbackAction: rule.FallbackAction,
+ FallbackErrorMessage: rule.FallbackErrorMessage,
+ }
+ // First match wins: later rules are never consulted.
+ return resolveRuleAction(eff, model)
+ }
+ return BetaPolicyActionPass, ""
+}
+
+// openAIFastPolicyCtxKey is the context key under which a prefetched
+// OpenAIFastPolicySettings snapshot is cached. It exists only so that the
+// frames of a long-lived WebSocket session reuse one policy snapshot instead
+// of hitting the DB on every frame.
+//
+// Trade-off: a policy change does not affect the current WS session (only
+// new sessions). This is intentional — for long sessions, policy consistency
+// matters more than taking effect immediately, and the Claude BetaPolicy
+// gin.Context cache makes the same trade-off. When a hot reload is needed,
+// an admin can force a refresh by disconnecting the session.
+type openAIFastPolicyCtxKeyType struct{}
+
+var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
+
+// withOpenAIFastPolicyContext binds a settings snapshot to ctx so that
+// evaluateOpenAIFastPolicy calls made with the returned context (or contexts
+// derived from it) reuse the snapshot. When ctx or settings is nil the input
+// ctx is returned unchanged.
+func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
+	if ctx != nil && settings != nil {
+		return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
+	}
+	return ctx
+}
+
+// openAIFastPolicySettingsFromContext returns the settings snapshot stored
+// under openAIFastPolicyCtxKey, or nil when ctx is nil or carries no
+// *OpenAIFastPolicySettings under that key.
+func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
+	if ctx == nil {
+		return nil
+	}
+	// A failed assertion leaves snapshot at its zero value, i.e. nil.
+	snapshot, _ := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings)
+	return snapshot
+}
+
+// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
+// body. When action=filter it removes the service_tier field; when
+// action=block it returns (body, *OpenAIFastBlockedError). On pass it
+// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
+// rewriting the body so the upstream receives a slug it recognizes.
+//
+// Rationale for normalize-on-pass: the chat-completions / messages entries
+// have already normalized service_tier to an upstream-recognized value via
+// normalizeResponsesBodyServiceTier before calling this function; entries
+// such as passthrough (OpenAI auto passthrough) and native /responses have
+// no such preliminary step, so without normalizing here on the pass path
+// "fast" would be forwarded verbatim to the OpenAI upstream and cause a
+// 400 / rejection. Centralizing the normalization in this function keeps
+// every entry consistent.
+//
+// An empty body, a missing/empty service_tier, or a value that normalizes to
+// "" is returned untouched with a nil error.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
+ if len(body) == 0 {
+ return body, nil
+ }
+ // .String() also stringifies non-string JSON values; those are expected
+ // to fall out below because they normalize to "".
+ rawTier := gjson.GetBytes(body, "service_tier").String()
+ if rawTier == "" {
+ return body, nil
+ }
+ normTier := normalizedOpenAIServiceTierValue(rawTier)
+ if normTier == "" {
+ return body, nil
+ }
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ // Use the rule's message when set, otherwise a generic one.
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+ }
+ return body, &OpenAIFastBlockedError{Message: msg}
+ case BetaPolicyActionFilter:
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ if err != nil {
+ return body, fmt.Errorf("strip service_tier from body: %w", err)
+ }
+ return trimmed, nil
+ default:
+ // pass: rewrite an alias (e.g. "fast") to its canonical value ("priority").
+ if normTier == rawTier {
+ return body, nil
+ }
+ updated, err := sjson.SetBytes(body, "service_tier", normTier)
+ if err != nil {
+ return body, fmt.Errorf("normalize service_tier on pass: %w", err)
+ }
+ return updated, nil
+ }
+}
+
+// writeOpenAIFastPolicyBlockedResponse answers a request rejected by the
+// OpenAI fast policy with HTTP 403 and a JSON body of the form
+// {"error":{"type":"permission_error","message":<err.Message>}}.
+// It is a no-op when either c or err is nil.
+func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
+	if c == nil || err == nil {
+		return
+	}
+	errBody := gin.H{
+		"type":    "permission_error",
+		"message": err.Message,
+	}
+	c.JSON(http.StatusForbidden, gin.H{"error": errBody})
+}
+
+// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
+// against a single client→upstream WebSocket frame whose top-level
+// "type"=="response.create". It mirrors the HTTP-side
+// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
+// WS payload:
+//
+// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
+// - filter: returns a copy with top-level service_tier removed
+// - block: returns (frame, *OpenAIFastBlockedError)
+//
+// NOTE(review): unlike applyOpenAIFastPolicyToBody, the pass path does not
+// rewrite an alias such as "fast" to its canonical value — confirm that
+// normalization happens elsewhere for WS frames or that it is not needed.
+//
+// Only frames whose "type" field strictly equals "response.create" are
+// inspected/mutated. Any other frame type — including the empty string —
+// passes through untouched. The OpenAI Realtime client-event spec requires
+// "type" to be set, so an empty type is treated as a malformed frame we do
+// not police; the upstream is the source of truth for rejecting it.
+//
+// service_tier lives at the top level of response.create — same as the
+// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
+// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
+// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
+// to inspect / strip the top-level field; there is no nested form in the
+// schema today.
+//
+// The caller is responsible for choosing the upstream model passed in —
+// this helper does not re-derive it.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
+ ctx context.Context,
+ account *Account,
+ model string,
+ frame []byte,
+) ([]byte, *OpenAIFastBlockedError, error) {
+ if len(frame) == 0 {
+ return frame, nil, nil
+ }
+ // Invalid JSON is forwarded untouched rather than rejected here.
+ if !gjson.ValidBytes(frame) {
+ return frame, nil, nil
+ }
+ frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
+ // Strict match: only response.create is policy-checked. Empty / other
+ // types pass through untouched so we never accidentally strip fields
+ // from response.cancel, conversation.item.create, or any future
+ // client-event the spec adds. The Realtime spec requires "type" on
+ // every client event, so an empty type is malformed input — let the
+ // upstream reject it rather than guessing at our layer.
+ if frameType != "response.create" {
+ return frame, nil, nil
+ }
+ // .String() also stringifies non-string JSON values; those are expected
+ // to fall out below because they normalize to "".
+ rawTier := gjson.GetBytes(frame, "service_tier").String()
+ if rawTier == "" {
+ return frame, nil, nil
+ }
+ normTier := normalizedOpenAIServiceTierValue(rawTier)
+ if normTier == "" {
+ return frame, nil, nil
+ }
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ // Use the rule's message when set, otherwise a generic one.
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+ }
+ return frame, &OpenAIFastBlockedError{Message: msg}, nil
+ case BetaPolicyActionFilter:
+ trimmed, err := sjson.DeleteBytes(frame, "service_tier")
+ if err != nil {
+ return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
+ }
+ return trimmed, nil, nil
+ default:
+ // pass (or any unrecognized action): forward the frame as received.
+ return frame, nil, nil
+ }
+}
+
+// newOpenAIFastPolicyWSEventID builds a Realtime-style event_id for a
+// server-emitted error event: "evt_" followed by a random UUID with its
+// dashes removed. The exact value is not load-bearing; it only has to be
+// non-empty so clients can correlate the rejection in their logs. The
+// already-imported google/uuid package is reused instead of adding a new
+// dependency.
+func newOpenAIFastPolicyWSEventID() string {
+	randomID, genErr := uuid.NewRandom()
+	if genErr == nil {
+		// Drop the dashes so the suffix is one compact token rather than
+		// the canonical UUID v4 form.
+		return "evt_" + strings.ReplaceAll(randomID.String(), "-", "")
+	}
+	// Extremely unlikely; fall back to a fixed value so the field is still
+	// non-empty and the event schema stays self-consistent.
+	return "evt_openai_fast_policy"
+}
+
+// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
+// style "error" event payload for a request blocked by the OpenAI fast
+// policy. The shape mirrors Realtime error events as observed in upstream
+// traces and per the spec's server "error" event:
+//
+// {
+// "event_id": "evt_<id>",
+// "type": "error",
+// "error": {
+// "type": "invalid_request_error",
+// "code": "policy_violation",
+// "message": "..."
+// }
+// }
+//
+// event_id lets clients correlate the rejection in their logs; "code" gives
+// programmatic clients a stable identifier (HTTP-side equivalent is the
+// 403 permission_error JSON body).
+//
+// It returns nil when err is nil; otherwise "message" carries err.Message.
+func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
+ if err == nil {
+ return nil
+ }
+ eventID := newOpenAIFastPolicyWSEventID()
+ payload, mErr := json.Marshal(map[string]any{
+ "event_id": eventID,
+ "type": "error",
+ "error": map[string]any{
+ "type": "invalid_request_error",
+ "code": "policy_violation",
+ "message": err.Message,
+ },
+ })
+ if mErr != nil {
+ // Fallback to a minimal hand-rolled payload; Marshal of the literal
+ // shape above should never fail in practice. The fallback uses a
+ // fixed message instead of err.Message, so no JSON escaping is needed.
+ return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
+ }
+ return payload
+}
+
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 9c08d5f2..5d1c6fc6 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
+// TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback
+// checks that GenerateExplicitSessionHash only yields a hash for explicit
+// session signals (prompt_cache_key in the body, session_id header) and
+// returns nothing for a body that carries neither.
+func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	svc := &OpenAIGatewayService{}
+
+	// newImagesContext returns a fresh gin test context carrying a bodiless
+	// POST /v1/images/generations request, so subtests never share headers.
+	newImagesContext := func() *gin.Context {
+		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+		ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
+		return ginCtx
+	}
+	// expectedHash renders the hash format the explicit-session subtests
+	// compare against.
+	expectedHash := func(sessionKey string) string {
+		return fmt.Sprintf("%016x", xxhash.Sum64String(sessionKey))
+	}
+
+	t.Run("stateless image body stays unstuck", func(t *testing.T) {
+		ginCtx := newImagesContext()
+		statelessBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
+
+		require.Empty(t, svc.GenerateExplicitSessionHash(ginCtx, statelessBody))
+		require.Empty(t, openAILegacySessionHashFromContext(ginCtx.Request.Context()))
+	})
+
+	t.Run("prompt_cache_key is explicit", func(t *testing.T) {
+		ginCtx := newImagesContext()
+		explicitBody := []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`)
+
+		sessionHash := svc.GenerateExplicitSessionHash(ginCtx, explicitBody)
+		require.Equal(t, expectedHash("image-session"), sessionHash)
+		require.NotEmpty(t, openAILegacySessionHashFromContext(ginCtx.Request.Context()))
+	})
+
+	t.Run("header overrides body", func(t *testing.T) {
+		ginCtx := newImagesContext()
+		ginCtx.Request.Header.Set("session_id", "header-session")
+
+		sessionHash := svc.GenerateExplicitSessionHash(ginCtx, []byte(`{"prompt_cache_key":"body-session"}`))
+		require.Equal(t, expectedHash("header-session"), sessionHash)
+	})
+}
+
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1756,6 +1791,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
}
}
+// TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields
+// feeds a Codex-style compact request through
+// normalizeOpenAICompactRequestBody and checks which fields survive and
+// which are dropped.
+func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
+	payload := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
+
+	got, changed, err := normalizeOpenAICompactRequestBody(payload)
+	require.NoError(t, err)
+	require.True(t, changed)
+
+	field := func(path string) gjson.Result {
+		return gjson.GetBytes(got, path)
+	}
+
+	// Fields that must pass through untouched.
+	require.Equal(t, "gpt-5.5", field("model").String())
+	require.True(t, field("tools").Exists())
+	require.True(t, field("parallel_tool_calls").Bool())
+	require.Equal(t, "high", field("reasoning.effort").String())
+	require.Equal(t, "low", field("text.verbosity").String())
+	require.Equal(t, "resp_123", field("previous_response_id").String())
+
+	// Fields that must be removed from the normalized body.
+	for _, removed := range []string{"store", "stream", "prompt_cache_key"} {
+		require.False(t, field(removed).Exists())
+	}
+}
+
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go
index 200547d4..47113d4d 100644
--- a/backend/internal/service/openai_images_test.go
+++ b/backend/internal/service/openai_images_test.go
@@ -11,6 +11,7 @@ import (
"strings"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
@@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
}
+// TestBuildOpenAIImagesURL_HandlesVersionedBaseURL checks that
+// buildOpenAIImagesURL produces the same /v1/images/... URL whether the base
+// URL ends in /v1, ends in /v1/, has no version segment, or already is the
+// full endpoint URL.
+func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
+	cases := []struct {
+		want string
+		got  string
+	}{
+		{
+			want: "https://image-upstream.example/v1/images/generations",
+			got:  buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
+		},
+		{
+			want: "https://image-upstream.example/v1/images/edits",
+			got:  buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
+		},
+		{
+			want: "https://image-upstream.example/v1/images/generations",
+			got:  buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
+		},
+		{
+			want: "https://image-upstream.example/v1/images/generations",
+			got:  buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
+		},
+	}
+	for _, tc := range cases {
+		require.Equal(t, tc.want, tc.got)
+	}
+}
+
type openAIImageTestSSEEvent struct {
Name string
Data string
@@ -371,6 +391,124 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
}
+// TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL
+// forwards a JSON image-generation request through an API-key account whose
+// base_url ends in /v1 and checks the upstream URL, headers and body as well
+// as the response relayed to the client.
+func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Inbound client request.
+	reqBody := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
+	inbound := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(reqBody))
+	inbound.Header.Set("Content-Type", "application/json")
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	ginCtx.Request = inbound
+
+	// Canned upstream reply handed back by the recording upstream stub.
+	cannedResponse := &http.Response{
+		StatusCode: http.StatusOK,
+		Header: http.Header{
+			"Content-Type": []string{"application/json"},
+			"X-Request-Id": []string{"req_img_apikey"},
+		},
+		Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+	}
+	svc := &OpenAIGatewayService{
+		cfg:          &config.Config{},
+		httpUpstream: &httpUpstreamRecorder{resp: cannedResponse},
+	}
+	parsed, err := svc.ParseOpenAIImagesRequest(ginCtx, reqBody)
+	require.NoError(t, err)
+
+	apiKeyAccount := &Account{
+		ID:       6,
+		Name:     "openai-apikey",
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key":  "test-api-key",
+			"base_url": "https://image-upstream.example/v1",
+		},
+	}
+
+	result, err := svc.ForwardImages(context.Background(), ginCtx, apiKeyAccount, reqBody, parsed, "")
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.Equal(t, 1, result.ImageCount)
+	require.Equal(t, "gpt-image-2", result.Model)
+	require.Equal(t, "gpt-image-2", result.UpstreamModel)
+
+	// What was actually sent upstream.
+	upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
+	require.True(t, ok)
+	require.NotNil(t, upstream.lastReq)
+	sent := upstream.lastReq
+	require.Equal(t, "https://image-upstream.example/v1/images/generations", sent.URL.String())
+	require.Equal(t, "Bearer test-api-key", sent.Header.Get("Authorization"))
+	require.Equal(t, "application/json", sent.Header.Get("Content-Type"))
+	require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
+
+	// What the client received.
+	require.Equal(t, http.StatusOK, recorder.Code)
+	require.Equal(t, "aGVsbG8=", gjson.Get(recorder.Body.String(), "data.0.b64_json").String())
+}
+
+// TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL
+// forwards a multipart image-edit request through an API-key account whose
+// base_url ends in /v1/ and checks the upstream URL, headers and body as
+// well as the response relayed to the client.
+func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Build the multipart form: model, prompt and one image file part.
+	var form bytes.Buffer
+	formWriter := multipart.NewWriter(&form)
+	require.NoError(t, formWriter.WriteField("model", "gpt-image-2"))
+	require.NoError(t, formWriter.WriteField("prompt", "replace background"))
+	filePart, err := formWriter.CreateFormFile("image", "source.png")
+	require.NoError(t, err)
+	_, err = filePart.Write([]byte("png-image-content"))
+	require.NoError(t, err)
+	require.NoError(t, formWriter.Close())
+	formBytes := form.Bytes()
+
+	// Inbound client request.
+	inbound := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(formBytes))
+	inbound.Header.Set("Content-Type", formWriter.FormDataContentType())
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	ginCtx.Request = inbound
+
+	// Canned upstream reply handed back by the recording upstream stub.
+	cannedResponse := &http.Response{
+		StatusCode: http.StatusOK,
+		Header: http.Header{
+			"Content-Type": []string{"application/json"},
+			"X-Request-Id": []string{"req_img_edit_apikey"},
+		},
+		Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
+	}
+	svc := &OpenAIGatewayService{
+		cfg:          &config.Config{},
+		httpUpstream: &httpUpstreamRecorder{resp: cannedResponse},
+	}
+	parsed, err := svc.ParseOpenAIImagesRequest(ginCtx, formBytes)
+	require.NoError(t, err)
+
+	apiKeyAccount := &Account{
+		ID:       7,
+		Name:     "openai-apikey",
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Credentials: map[string]any{
+			"api_key":  "test-api-key",
+			"base_url": "https://image-upstream.example/v1/",
+		},
+	}
+
+	result, err := svc.ForwardImages(context.Background(), ginCtx, apiKeyAccount, formBytes, parsed, "")
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.Equal(t, 1, result.ImageCount)
+
+	// What was actually sent upstream.
+	upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
+	require.True(t, ok)
+	require.NotNil(t, upstream.lastReq)
+	sent := upstream.lastReq
+	sentBody := string(upstream.lastBody)
+	require.Equal(t, "https://image-upstream.example/v1/images/edits", sent.URL.String())
+	require.Equal(t, "Bearer test-api-key", sent.Header.Get("Authorization"))
+	require.Contains(t, sent.Header.Get("Content-Type"), "multipart/form-data")
+	require.Contains(t, sentBody, `name="model"`)
+	require.Contains(t, sentBody, "gpt-image-2")
+
+	// What the client received.
+	require.Equal(t, http.StatusOK, recorder.Code)
+	require.Equal(t, "ZWRpdGVk", gjson.Get(recorder.Body.String(), "data.0.b64_json").String())
+}
+
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
diff --git a/backend/internal/service/openai_passthrough_normalization_test.go b/backend/internal/service/openai_passthrough_normalization_test.go
new file mode 100644
index 00000000..492ff610
--- /dev/null
+++ b/backend/internal/service/openai_passthrough_normalization_test.go
@@ -0,0 +1,33 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser checks the
+// non-compact normalization: every field listed in
+// openAIChatGPTInternalUnsupportedFields is stripped, stream reads as true
+// and store reads as false afterwards.
+func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
+	payload := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
+
+	got, changed, err := normalizeOpenAIPassthroughOAuthBody(payload, false)
+	require.NoError(t, err)
+	require.True(t, changed)
+
+	lookup := func(path string) gjson.Result {
+		return gjson.GetBytes(got, path)
+	}
+	for _, unsupported := range openAIChatGPTInternalUnsupportedFields {
+		require.False(t, lookup(unsupported).Exists(), "%s should be stripped", unsupported)
+	}
+	require.True(t, lookup("stream").Bool())
+	require.False(t, lookup("store").Bool())
+}
+
+// TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser
+// checks the compact normalization: user, metadata, stream and store are all
+// absent from the normalized body.
+func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
+	payload := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
+
+	got, changed, err := normalizeOpenAIPassthroughOAuthBody(payload, true)
+	require.NoError(t, err)
+	require.True(t, changed)
+
+	for _, removed := range []string{"user", "metadata", "stream", "store"} {
+		require.False(t, gjson.GetBytes(got, removed).Exists())
+	}
+}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 8c0222e2..d1386b1b 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -1366,16 +1366,27 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
func shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled bool,
turn int,
- hasFunctionCallOutput bool,
+ signals ToolContinuationSignals,
currentPreviousResponseID string,
expectedPreviousResponseID string,
) bool {
- if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
+ // Inference only applies to store=false sessions, from the second turn
+ // on, and only when the frame carries a function_call_output item.
+ if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
return false
}
+ // A previous_response_id supplied by the client is never overridden.
if strings.TrimSpace(currentPreviousResponseID) != "" {
return false
}
+ // NOTE(review): presumably an output without a call_id cannot be tied to
+ // a prior function_call, so no anchor is inferred — confirm.
+ if signals.HasFunctionCallOutputMissingCallID {
+ return false
+ }
+ // If the client already sent the actual tool-call context, treat this as
+ // a full replay / self-contained continuation payload rather than
+ // downgrading it into an inferred delta continuation. item_reference alone
+ // is not enough on the store=false WS path: it still needs a valid prior
+ // response anchor so upstream can resolve the referenced function_call.
+ if signals.HasToolCallContext {
+ return false
+ }
+ // Infer only when there is a non-blank response ID to attach.
return strings.TrimSpace(expectedPreviousResponseID) != ""
}
@@ -2366,6 +2377,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("token is empty")
}
+ // 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
+ // 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
+ // 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
+ if s.settingService != nil {
+ if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
+ ctx = withOpenAIFastPolicyContext(ctx, settings)
+ }
+ }
+
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
@@ -2524,6 +2544,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
+ // Apply OpenAI Fast Policy on the response.create frame using the same
+ // evaluator/normalize/scope rules as the HTTP entrypoints. This is the
+ // single integration point for all WS ingress turns (first + follow-up
+ // frames flow through here).
+ //
+ // Model fallback: parseClientPayload above rejects any frame whose
+ // "model" field is missing (line ~2493-2500), so by the time we
+ // reach this point upstreamModel is always derived from a non-empty
+ // per-frame model. The capturedSessionModel fallback used in the
+ // passthrough adapter is therefore not needed in this path.
+ policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
+ if policyErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
+ }
+ if blocked != nil {
+ // Send a Realtime-style error event to the client first, then
+ // signal the handler to close the connection with PolicyViolation.
+ // We intentionally do NOT forward this frame upstream.
+ //
+ // coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
+ // the underlying bufio writer before returning (write.go:42 →
+ // 307-311), and the subsequent close handshake re-acquires the
+ // same writeFrameMu, so the error event is guaranteed to reach
+ // the kernel send buffer before any close frame is queued.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes != nil {
+ writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancel()
+ }
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ blocked.Message,
+ blocked,
+ )
+ }
+ normalized = policyApplied
+
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,
@@ -3132,13 +3190,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
- hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+ toolSignals := ToolContinuationSignals{
+ HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
+ }
+ if toolSignals.HasFunctionCallOutput {
+ var currentReqBody map[string]any
+ if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil {
+ toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
+ }
+ }
+ hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
// store=false + function_call_output 场景必须有续链锚点。
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
if shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled,
turn,
- hasFunctionCallOutput,
+ toolSignals,
currentPreviousResponseID,
expectedPrev,
) {
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
index 6bf9a9ff..30fd4142 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
}
+// Verifies that on a store=false ingress WebSocket session, a second-turn
+// frame that carries both the function_call item and its
+// function_call_output is forwarded upstream without an injected
+// previous_response_id.
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // Gateway config: URL allowlist off, OpenAI WS v2 enabled for OAuth and
+ // API-key accounts, a single upstream connection per account, and 3s
+ // dial/read/write timeouts.
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ // Fake upstream connection scripted with one response.completed event
+ // per turn; its recorded writes are inspected at the end of the test.
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{captureConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 114,
+ Name: "openai-ingress-tool-context",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ // Test server: accept the client WebSocket, read the first frame, then
+ // hand the connection to ProxyResponsesWebSocketFromClient. Its result
+ // (or any earlier failure) is reported through serverErrCh.
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ // Client-side helpers: each write/read is bounded by a 3s timeout.
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ // Turn 1: plain store=false request.
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ // Turn 2: function_call and its function_call_output sent together, with
+ // no previous_response_id.
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ // One upstream dial, two upstream writes, and the second write must not
+ // contain previous_response_id.
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
+}
+
+// Verifies that on a store=false ingress WebSocket session, a second-turn
+// frame that carries only an item_reference plus a function_call_output is
+// forwarded upstream with previous_response_id set to the first turn's
+// response ID.
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // Gateway config: URL allowlist off, OpenAI WS v2 enabled for OAuth and
+ // API-key accounts, a single upstream connection per account, and 3s
+ // dial/read/write timeouts.
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ // Fake upstream connection scripted with one response.completed event
+ // per turn; its recorded writes are inspected at the end of the test.
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{captureConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 115,
+ Name: "openai-ingress-item-reference",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ // Test server: accept the client WebSocket, read the first frame, then
+ // hand the connection to ProxyResponsesWebSocketFromClient. Its result
+ // (or any earlier failure) is reported through serverErrCh.
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ // Client-side helpers: each write/read is bounded by a 3s timeout.
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ // Turn 1: plain store=false request.
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ // Turn 2: only an item_reference accompanies the function_call_output,
+ // and the client sends no previous_response_id.
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ // One upstream dial, two upstream writes, and the second write must carry
+ // the first turn's response ID as previous_response_id.
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID")
+}
+
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go
index ff35cb01..c735f50a 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go
@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
name string
storeDisabled bool
turn int
- hasFunctionCallOutput bool
+ signals ToolContinuationSignals
currentPreviousResponse string
expectedPrevious string
want bool
}{
{
- name: "infer_when_all_conditions_match",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: true,
+ name: "infer_when_all_conditions_match",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: true,
},
{
- name: "skip_when_store_enabled",
- storeDisabled: false,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_when_store_enabled",
+ storeDisabled: false,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
- name: "skip_on_first_turn",
- storeDisabled: true,
- turn: 1,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_on_first_turn",
+ storeDisabled: true,
+ turn: 1,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
- name: "skip_without_function_call_output",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: false,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_without_function_call_output",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
name: "skip_when_request_already_has_previous_response_id",
storeDisabled: true,
turn: 2,
- hasFunctionCallOutput: true,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
currentPreviousResponse: "resp_client",
expectedPrevious: "resp_1",
want: false,
},
{
- name: "skip_when_last_turn_response_id_missing",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "",
- want: false,
+ name: "skip_when_last_turn_response_id_missing",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "",
+ want: false,
},
{
- name: "trim_whitespace_before_judgement",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: " resp_2 ",
- want: true,
+ name: "trim_whitespace_before_judgement",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: " resp_2 ",
+ want: true,
+ },
+ {
+ name: "skip_when_tool_call_context_already_present",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
+ expectedPrevious: "resp_2",
+ want: false,
+ },
+ {
+ name: "infer_when_only_item_reference_covers_call_ids",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
+ expectedPrevious: "resp_2",
+ want: true,
+ },
+ {
+ name: "skip_when_function_call_output_missing_call_id",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
+ expectedPrevious: "resp_2",
+ want: false,
},
}
@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
tt.storeDisabled,
tt.turn,
- tt.hasFunctionCallOutput,
+ tt.signals,
tt.currentPreviousResponse,
tt.expectedPrevious,
)
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
index 66e5db93..f3936de1 100644
--- a/backend/internal/service/openai_ws_protocol_forward_test.go
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
+ nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
index cda2e351..3dbb199a 100644
--- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go
+++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
+// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
+// every client→upstream frame through the OpenAI Fast Policy. It is the
+// passthrough-relay equivalent of the parseClientPayload integration in the
+// ingress session path. filter returns:
+// - newPayload, nil, nil: forward the (possibly mutated) payload
+// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
+// event via onBlock and surfaces a transport-level error so the relay
+// stops reading from the client.
+// - _, _, err: a transport error other than block.
+type openAIWSPolicyEnforcingFrameConn struct {
+ inner openaiwsv2.FrameConn
+ filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
+ onBlock func(blocked *OpenAIFastBlockedError)
+}
+
+var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
+
+func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if c == nil || c.inner == nil {
+ return coderws.MessageText, nil, errOpenAIWSConnClosed
+ }
+ msgType, payload, err := c.inner.ReadFrame(ctx)
+ if err != nil {
+ return msgType, payload, err
+ }
+ if c.filter == nil {
+ return msgType, payload, nil
+ }
+ updated, blocked, filterErr := c.filter(msgType, payload)
+ if filterErr != nil {
+ return msgType, payload, filterErr
+ }
+ if blocked != nil {
+ if c.onBlock != nil {
+ c.onBlock(blocked)
+ }
+ return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+ }
+ return msgType, updated, nil
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+ if c == nil || c.inner == nil {
+ return errOpenAIWSConnClosed
+ }
+ return c.inner.WriteFrame(ctx, msgType, payload)
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
+ if c == nil || c.inner == nil {
+ return nil
+ }
+ return c.inner.Close()
+}
+
+// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
+// model name that should be passed to evaluateOpenAIFastPolicy for a single
+// passthrough WS frame. Mirrors the HTTP-side normalization
+// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
+// matches model whitelists identically.
+func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
+ if account == nil || len(payload) == 0 {
+ return ""
+ }
+ original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
+ if original == "" {
+ return ""
+ }
+ return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
+// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
+// derived from a session.update frame's session.model field. Returns "" when
+// the frame is not a session.update event or carries no session.model. Used
+// by the per-frame policy filter (client→upstream direction) to keep
+// capturedSessionModel in sync with the session-level model the client may
+// rotate mid-session.
+//
+// Realtime / Responses WS lets the client change the session model after
+// the WS handshake via:
+//
+// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
+//
+// If we only capture the model from the very first frame, a client can ship
+// gpt-4o on the first response.create (whitelisted as pass), then
+// session.update to gpt-5.5, then send response.create without "model" so
+// the per-frame resolver returns "" and the stale capturedSessionModel falls
+// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
+func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
+ if account == nil || len(payload) == 0 {
+ return ""
+ }
+ frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
+ if frameType != "session.update" {
+ return ""
+ }
+ original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
+ if original == "" {
+ return ""
+ }
+ return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
- requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
+ // Apply OpenAI Fast Policy on the first response.create frame. Subsequent
+ // frames are filtered via a wrapping FrameConn below so every client→
+ // upstream frame goes through the same policy evaluator/normalize/scope as
+ // HTTP entrypoints.
+ //
+ // We capture the session-level model from the first frame here so the
+ // per-frame filter (below) can fall back to it when a follow-up frame
+ // omits "model" — Realtime clients are allowed to send response.create
+ // without re-stating the model, in which case the upstream uses the model
+ // negotiated at session.update time. Without this fallback, an empty
+ // model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
+ // silently passed through, defeating the policy on every frame after
+ // the first.
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
+ updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
+ if policyErr != nil {
+ return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
+ }
+ if blocked != nil {
+ // coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
+ // writeFrameMu, writes the entire frame, and Flushes the underlying
+ // bufio writer before returning (write.go:42 → write.go:307-311).
+ // The subsequent close handshake re-acquires the same writeFrameMu
+ // to send the close frame, so the error event is guaranteed to
+ // reach the kernel send buffer before any close frame is queued.
+ // No explicit flush hop is required here.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes != nil {
+ writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancelWrite()
+ }
+ return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+ }
+ firstClientMessage = updatedFirst
+
+ // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
+ // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
+ // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
+ // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
+ // 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
+ //
+ // 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
+ // 同一连接的不同 response.create 帧上发送不同 service_tier(参考
+ // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
+ // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
+ // goroutine)和 OnTurnComplete / final result(runUpstreamToClient
+ // goroutine)之间同步当前 turn 的 service_tier。
+ // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
+ // 可直接 Store/Load 而无需额外封装。
+ var requestServiceTierPtr atomic.Pointer[string]
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
+
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
+ policyClientConn := &openAIWSPolicyEnforcingFrameConn{
+ inner: &openAIWSClientFrameConn{conn: clientConn},
+ // 注意线程安全:filter 仅在 runClientToUpstream 这一条
+ // goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
+ // capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
+ // 加锁/原子化。
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ // 在评估策略前先刷新 capturedSessionModel:客户端可能通过
+ // session.update 修改 session-level model(Realtime /
+ // Responses WS 协议允许),如果不刷新就会出现
+ // "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
+ // → 不带 model 的 response.create fallback 到 gpt-4o" 的
+ // 绕过路径。这里只看 session.update 事件中的 session.model
+ // 字段,response.create 自己的 model 仍然由其本帧字段决定。
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated
+ }
+ // Per-frame model first; if the client omits "model" on a
+ // follow-up frame (legal in Realtime), fall back to the
+ // session-level model captured from the first frame so the
+ // model whitelist still resolves. An empty model would miss
+ // any whitelist and silently fall back to pass.
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
+ // 多轮 passthrough billing:仅在成功(non-block / non-err)
+ // 的 response.create 帧上更新 requestServiceTierPtr,使用
+ // filter 处理后的 payload,与首帧 policy-after-extract 语义
+ // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
+ // - 非 response.create 帧(response.cancel /
+ // conversation.item.create / session.update 等)不携带
+ // per-response service_tier,不应覆盖前一轮值。
+ // - blocked != nil:该帧不会发送上游,billing tier 应保持
+ // 上一轮值。
+ // - policyErr != nil:异常路径,保持上一轮值。
+ // - 不带 service_tier 的 response.create 会让
+ // extractOpenAIServiceTierFromBody 返回 nil;这里有意
+ // 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
+ // service_tier 时按 default 处理,billing 应如实反映。
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ },
+ onBlock: func(blocked *OpenAIFastBlockedError) {
+ // See note above on Conn.Write being synchronous w.r.t. flush;
+ // no explicit flush is required to ensure the error event lands
+ // before the close frame.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes == nil {
+ return
+ }
+ writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancel()
+ },
+ }
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
- ClientConn: &openAIWSClientFrameConn{conn: clientConn},
+ ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
- ServiceTier: requestServiceTier,
+ ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
- ServiceTier: requestServiceTier,
+ ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go
index 08a10a02..44ec1ad1 100644
--- a/backend/internal/service/ops_cleanup_service.go
+++ b/backend/internal/service/ops_cleanup_service.go
@@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
)
}
+// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。
+// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据
+// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true
+// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天
+//
+// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE":
+// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成
+// - 无 WAL 写入、无后续 VACUUM 压力
+// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略
+func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
+ if days < 0 {
+ return time.Time{}, false, false
+ }
+ if days == 0 {
+ return time.Time{}, true, true
+ }
+ return now.AddDate(0, 0, -days), false, true
+}
+
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
out := opsCleanupDeletedCounts{}
if s == nil || s.db == nil || s.cfg == nil {
@@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
now := time.Now().UTC()
- // Error-like tables: error logs / retry attempts / alert events.
- if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
+ // runOne 把"truncate? cutoff? batched delete?"封装到一处,
+ // 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。
+ runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
+ if truncate {
+ return truncateOpsTable(ctx, s.db, table)
+ }
+ return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
+ }
+
+ // Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
if err != nil {
return out, err
}
out.errorLogs = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
if err != nil {
return out, err
}
out.retryAttempts = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
if err != nil {
return out, err
}
out.alertEvents = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
if err != nil {
return out, err
}
out.systemLogs = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
if err != nil {
return out, err
}
@@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Minute-level metrics snapshots.
- if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
if err != nil {
return out, err
}
@@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Pre-aggregation tables (hourly/daily).
- if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
if err != nil {
return out, err
}
out.hourlyPreagg = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
+ n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
if err != nil {
return out, err
}
@@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
if err != nil {
// If ops tables aren't present yet (partial deployments), treat as no-op.
- if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
+ if isMissingRelationError(err) {
return total, nil
}
return total, err
@@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
return total, nil
}
+// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。
+//
+// 与 deleteOldRowsByID 的差异:
+// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义
+// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力
+// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略
+//
+// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。
+func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
+ if db == nil {
+ return 0, nil
+ }
+ var count int64
+ if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
+ if isMissingRelationError(err) {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("count %s: %w", table, err)
+ }
+ if count == 0 {
+ return 0, nil
+ }
+ if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
+ if isMissingRelationError(err) {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("truncate %s: %w", table, err)
+ }
+ return count, nil
+}
+
+// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。
+func isMissingRelationError(err error) bool {
+ if err == nil {
+ return false
+ }
+ s := strings.ToLower(err.Error())
+ return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
+}
+
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
if s == nil {
return nil, false
diff --git a/backend/internal/service/ops_cleanup_service_test.go b/backend/internal/service/ops_cleanup_service_test.go
new file mode 100644
index 00000000..86657d27
--- /dev/null
+++ b/backend/internal/service/ops_cleanup_service_test.go
@@ -0,0 +1,64 @@
+package service
+
+import (
+ "testing"
+ "time"
+)
+
+func TestOpsCleanupPlan(t *testing.T) {
+ now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
+
+ cases := []struct {
+ name string
+ days int
+ wantOK bool
+ wantTruncate bool
+ wantCutoff time.Time
+ }{
+ {name: "negative skips", days: -1, wantOK: false},
+ {name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
+ {name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
+ if ok != tc.wantOK {
+ t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if truncate != tc.wantTruncate {
+ t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
+ }
+ if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
+ t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
+ }
+ })
+ }
+}
+
+func TestIsMissingRelationError(t *testing.T) {
+ cases := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {name: "nil is not missing", err: nil, want: false},
+ {name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
+ {name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
+ {name: "non-matching error", err: fakeErr("connection refused"), want: false},
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got := isMissingRelationError(tc.err); got != tc.want {
+ t.Fatalf("got %v, want %v", got, tc.want)
+ }
+ })
+ }
+}
+
+type fakeErr string
+
+func (e fakeErr) Error() string { return string(e) }
diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go
index 5871166c..ecc3a94b 100644
--- a/backend/internal/service/ops_settings.go
+++ b/backend/internal/service/ops_settings.go
@@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
if cfg.DataRetention.CleanupSchedule == "" {
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
}
- if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
+ // 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留;
+ // 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。
+ if cfg.DataRetention.ErrorLogRetentionDays < 0 {
cfg.DataRetention.ErrorLogRetentionDays = 30
}
- if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
+ if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
cfg.DataRetention.MinuteMetricsRetentionDays = 30
}
- if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
+ if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
cfg.DataRetention.HourlyMetricsRetentionDays = 30
}
// Normalize auto refresh interval (default 30 seconds)
@@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
if cfg == nil {
return errors.New("invalid config")
}
- if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
- return errors.New("error_log_retention_days must be between 1 and 365")
+ // 保留天数:0 表示每次清理全部,1-365 表示按天数保留。
+ if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
+ return errors.New("error_log_retention_days must be between 0 and 365")
}
- if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
- return errors.New("minute_metrics_retention_days must be between 1 and 365")
+ if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
+ return errors.New("minute_metrics_retention_days must be between 0 and 365")
}
- if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
- return errors.New("hourly_metrics_retention_days must be between 1 and 365")
+ if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
+ return errors.New("hourly_metrics_retention_days must be between 0 and 365")
}
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index c6167447..5df69aea 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -269,7 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
- s.applyAffiliateRebateForOrder(ctx, o)
+ if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
+ return err
+ }
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
@@ -283,7 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
- s.applyAffiliateRebateForOrder(ctx, o)
+ if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
+ return err
+ }
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
@@ -361,12 +365,12 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
-func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) {
+func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
- return
+ return nil
}
if s.affiliateService == nil {
- return
+ return nil
}
tx, err := s.entClient.Tx(ctx)
@@ -374,7 +378,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
})
- return
+ return fmt.Errorf("begin affiliate rebate tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
@@ -384,10 +388,10 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
- return
+ return fmt.Errorf("claim affiliate rebate audit: %w", err)
}
if !claimed {
- return
+ return nil
}
rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
@@ -395,7 +399,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
- return
+ return fmt.Errorf("accrue affiliate rebate: %w", err)
}
if rebateAmount <= 0 {
@@ -406,14 +410,15 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
- return
+ return fmt.Errorf("update affiliate rebate skipped audit: %w", err)
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
+ return fmt.Errorf("commit affiliate rebate tx: %w", err)
}
- return
+ return nil
}
if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
@@ -423,14 +428,16 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": err.Error(),
})
- return
+ return fmt.Errorf("update affiliate rebate applied audit: %w", err)
}
if err := tx.Commit(); err != nil {
s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
"error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
})
+ return fmt.Errorf("commit affiliate rebate tx: %w", err)
}
+ return nil
}
func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
@@ -444,11 +451,11 @@ func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, clien
})
rows, err := client.QueryContext(ctx, `
INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
-SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW()
+SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
WHERE NOT EXISTS (
SELECT 1
FROM payment_audit_logs
- WHERE order_id = $1
+ WHERE order_id = $1::text
AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
)
ON CONFLICT (order_id, action) DO NOTHING
diff --git a/backend/internal/service/scheduler_cache.go b/backend/internal/service/scheduler_cache.go
index f36135e0..f9794c82 100644
--- a/backend/internal/service/scheduler_cache.go
+++ b/backend/internal/service/scheduler_cache.go
@@ -59,6 +59,8 @@ type SchedulerCache interface {
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
// TryLockBucket 尝试获取分桶重建锁。
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
+ // UnlockBucket 释放分桶重建锁。
+ UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
// ListBuckets 返回已注册的分桶集合。
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
// GetOutboxWatermark 读取 outbox 水位。
diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go
index 5c0b289b..0b32c2ad 100644
--- a/backend/internal/service/scheduler_snapshot_hydration_test.go
+++ b/backend/internal/service/scheduler_snapshot_hydration_test.go
@@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
return true, nil
}
+func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
+ return nil
+}
+
func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
return nil, nil
}
diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go
index 62b6993d..a68cdf0c 100644
--- a/backend/internal/service/scheduler_snapshot_service.go
+++ b/backend/internal/service/scheduler_snapshot_service.go
@@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
if !ok {
return nil
}
+ defer func() {
+ _ = s.cache.UnlockBucket(ctx, bucket)
+ }()
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index f871ee85..2bae686a 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second
// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL)
type cachedGatewayForwardingSettings struct {
- fingerprintUnification bool
- metadataPassthrough bool
- cchSigning bool
- expiresAt int64 // unix nano
+ fingerprintUnification bool
+ metadataPassthrough bool
+ cchSigning bool
+ anthropicCacheTTL1hInjection bool
+ expiresAt int64 // unix nano
}
var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings
@@ -1175,6 +1176,24 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
+ if settings.AffiliateRebateFreezeHours < 0 {
+ settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
+ }
+ if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
+ settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
+ }
+ updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
+ if settings.AffiliateRebateDurationDays < 0 {
+ settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
+ }
+ if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
+ settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
+ }
+ updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
+ if settings.AffiliateRebatePerInviteeCap < 0 {
+ settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
+ }
+ updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
@@ -1227,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
@@ -1287,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
})
gatewayForwardingSF.Forget("gateway_forwarding")
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: settings.EnableFingerprintUnification,
- metadataPassthrough: settings.EnableMetadataPassthrough,
- cchSigning: settings.EnableCCHSigning,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ fingerprintUnification: settings.EnableFingerprintUnification,
+ metadataPassthrough: settings.EnableMetadataPassthrough,
+ cchSigning: settings.EnableCCHSigning,
+ anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
@@ -1397,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
return false
}
-// GetGatewayForwardingSettings returns cached gateway forwarding settings.
-// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
-// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
-func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
+// gatewayForwardingSettingsResult is the plain-value snapshot of the gateway
+// forwarding switches returned by getGatewayForwardingSettingsCached.
+type gatewayForwardingSettingsResult struct {
+	fp         bool // fingerprint unification
+	mp         bool // metadata passthrough
+	cch        bool // cch signing
+	cacheTTL1h bool // Anthropic 1h cache TTL injection
+}
+
+func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
- return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
+ return gatewayForwardingSettingsResult{
+ fp: cached.fingerprintUnification,
+ mp: cached.metadataPassthrough,
+ cch: cached.cchSigning,
+ cacheTTL1h: cached.anthropicCacheTTL1hInjection,
+ }
}
}
- type gwfResult struct {
- fp, mp, cch bool
- }
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
- return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
+ return gatewayForwardingSettingsResult{
+ fp: cached.fingerprintUnification,
+ mp: cached.metadataPassthrough,
+ cch: cached.cchSigning,
+ cacheTTL1h: cached.anthropicCacheTTL1hInjection,
+ }, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
@@ -1421,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
SettingKeyEnableFingerprintUnification,
SettingKeyEnableMetadataPassthrough,
SettingKeyEnableCCHSigning,
+ SettingKeyEnableAnthropicCacheTTL1hInjection,
})
if err != nil {
slog.Warn("failed to get gateway forwarding settings", "error", err)
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: true,
- metadataPassthrough: false,
- cchSigning: false,
- expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
+ fingerprintUnification: true,
+ metadataPassthrough: false,
+ cchSigning: false,
+ anthropicCacheTTL1hInjection: false,
+ expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
})
- return gwfResult{true, false, false}, nil
+ return gatewayForwardingSettingsResult{fp: true}, nil
}
fp := true
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
@@ -1438,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
}
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
cch := values[SettingKeyEnableCCHSigning] == "true"
+ cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: fp,
- metadataPassthrough: mp,
- cchSigning: cch,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ fingerprintUnification: fp,
+ metadataPassthrough: mp,
+ cchSigning: cch,
+ anthropicCacheTTL1hInjection: cacheTTL1h,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
- return gwfResult{fp, mp, cch}, nil
+ return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
})
- if r, ok := val.(gwfResult); ok {
- return r.fp, r.mp, r.cch
+ if r, ok := val.(gatewayForwardingSettingsResult); ok {
+ return r
}
- return true, false, false // fail-open defaults
+ return gatewayForwardingSettingsResult{fp: true}
+}
+
+// GetGatewayForwardingSettings reports the gateway forwarding switches in the
+// order (fingerprintUnification, metadataPassthrough, cchSigning).
+// It is a thin projection of getGatewayForwardingSettingsCached, which serves
+// the values from the in-process atomic.Value cache (60s TTL) on the hot path.
+func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
+	snapshot := s.getGatewayForwardingSettingsCached(ctx)
+	fingerprintUnification = snapshot.fp
+	metadataPassthrough = snapshot.mp
+	cchSigning = snapshot.cch
+	return fingerprintUnification, metadataPassthrough, cchSigning
+}
+
+// IsAnthropicCacheTTL1hInjectionEnabled reports whether a 1h cache_control ttl
+// should be injected into Anthropic OAuth/SetupToken request bodies.
+// The flag is read from the same cached snapshot as the other gateway
+// forwarding switches.
+func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool {
+	snapshot := s.getGatewayForwardingSettingsCached(ctx)
+	return snapshot.cacheTTL1h
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
@@ -1512,6 +1558,54 @@ func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) floa
return clampAffiliateRebateRate(rate)
}
+// GetAffiliateRebateFreezeHours returns the affiliate rebate freeze period in hours.
+// A value of 0 means rebates are not frozen (backward compatible).
+// A lookup failure, an unparsable value or a negative value yields
+// AffiliateRebateFreezeHoursDefault; values above AffiliateRebateFreezeHoursMax
+// are clamped to that maximum.
+func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
+	stored, lookupErr := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
+	if lookupErr != nil {
+		return AffiliateRebateFreezeHoursDefault
+	}
+	parsed, parseErr := strconv.Atoi(strings.TrimSpace(stored))
+	switch {
+	case parseErr != nil, parsed < 0:
+		return AffiliateRebateFreezeHoursDefault
+	case parsed > AffiliateRebateFreezeHoursMax:
+		return AffiliateRebateFreezeHoursMax
+	default:
+		return parsed
+	}
+}
+
+// GetAffiliateRebateDurationDays returns how long affiliate rebates stay valid, in days.
+// A value of 0 means the rebate never expires.
+// A lookup failure, an unparsable value or a negative value yields
+// AffiliateRebateDurationDaysDefault; values above AffiliateRebateDurationDaysMax
+// are clamped to that maximum.
+func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
+	stored, lookupErr := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
+	if lookupErr != nil {
+		return AffiliateRebateDurationDaysDefault
+	}
+	parsed, parseErr := strconv.Atoi(strings.TrimSpace(stored))
+	switch {
+	case parseErr != nil, parsed < 0:
+		return AffiliateRebateDurationDaysDefault
+	case parsed > AffiliateRebateDurationDaysMax:
+		return AffiliateRebateDurationDaysMax
+	default:
+		return parsed
+	}
+}
+
+// GetAffiliateRebatePerInviteeCap returns the maximum rebate that a single invitee can generate.
+// A value of 0 means there is no cap.
+// A lookup failure or a stored value that is unparsable, negative, NaN or
+// infinite yields AffiliateRebatePerInviteeCapDefault.
+func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
+	stored, lookupErr := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
+	if lookupErr != nil {
+		return AffiliateRebatePerInviteeCapDefault
+	}
+	// Named "limit" rather than "cap" so the builtin cap() is not shadowed.
+	limit, parseErr := strconv.ParseFloat(strings.TrimSpace(stored), 64)
+	unusable := parseErr != nil || limit < 0 || math.IsNaN(limit) || math.IsInf(limit, 0)
+	if unusable {
+		return AffiliateRebatePerInviteeCapDefault
+	}
+	return limit
+}
+
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -1755,6 +1849,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
+ SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
+ SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
+ SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
SettingKeyDefaultUserRPMLimit: "0",
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
@@ -1811,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
- SettingKeyAllowUngroupedKeyScheduling: "false",
- SettingPaymentVisibleMethodAlipaySource: "",
- SettingPaymentVisibleMethodWxpaySource: "",
- SettingPaymentVisibleMethodAlipayEnabled: "false",
- SettingPaymentVisibleMethodWxpayEnabled: "false",
- openAIAdvancedSchedulerSettingKey: "false",
+ SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
+ SettingPaymentVisibleMethodAlipaySource: "",
+ SettingPaymentVisibleMethodWxpaySource: "",
+ SettingPaymentVisibleMethodAlipayEnabled: "false",
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -1890,6 +1988,21 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.AffiliateRebateRate = AffiliateRebateRateDefault
}
+ if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
+ if freezeHours > AffiliateRebateFreezeHoursMax {
+ freezeHours = AffiliateRebateFreezeHoursMax
+ }
+ result.AffiliateRebateFreezeHours = freezeHours
+ }
+ if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
+ if durationDays > AffiliateRebateDurationDaysMax {
+ durationDays = AffiliateRebateDurationDaysMax
+ }
+ result.AffiliateRebateDurationDays = durationDays
+ }
+ if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
+ result.AffiliateRebatePerInviteeCap = perInviteeCap
+ }
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -2144,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
+ result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
@@ -3175,6 +3289,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
+// GetOpenAIFastPolicySettings loads the OpenAI fast policy configuration.
+//
+// DefaultOpenAIFastPolicySettings is returned when the setting is missing,
+// stored as an empty string, or cannot be decoded as JSON. Any other
+// repository error is wrapped and returned.
+func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
+	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
+	switch {
+	case err == nil:
+		// Value found; decode it below.
+	case errors.Is(err, ErrSettingNotFound):
+		return DefaultOpenAIFastPolicySettings(), nil
+	default:
+		return nil, fmt.Errorf("get openai fast policy settings: %w", err)
+	}
+	if len(raw) == 0 {
+		return DefaultOpenAIFastPolicySettings(), nil
+	}
+
+	parsed := new(OpenAIFastPolicySettings)
+	if decodeErr := json.Unmarshal([]byte(raw), parsed); decodeErr != nil {
+		// Falling back silently on corrupt JSON would make the admin's
+		// block/filter rules disappear without a trace. Log a warning so
+		// operators can trace odd behavior back to the bad settings row.
+		slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
+			"error", decodeErr,
+			"key", SettingKeyOpenAIFastPolicySettings)
+		return DefaultOpenAIFastPolicySettings(), nil
+	}
+
+	return parsed, nil
+}
+
+// SetOpenAIFastPolicySettings validates, normalizes and persists the OpenAI fast
+// policy configuration.
+//
+// Normalization happens in place on the caller's settings: each rule's
+// ServiceTier is trimmed and lower-cased (empty becomes OpenAIFastTierAny) and
+// every ModelWhitelist entry is trimmed. Rules that precede an invalid rule
+// have already been normalized when the validation error is returned.
+func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
+	if settings == nil {
+		return fmt.Errorf("settings cannot be nil")
+	}
+
+	isKnownTier := func(v string) bool {
+		switch v {
+		case OpenAIFastTierAny, OpenAIFastTierPriority, OpenAIFastTierFlex:
+			return true
+		}
+		return false
+	}
+	isKnownAction := func(v string) bool {
+		switch v {
+		case BetaPolicyActionPass, BetaPolicyActionFilter, BetaPolicyActionBlock:
+			return true
+		}
+		return false
+	}
+	isKnownScope := func(v string) bool {
+		switch v {
+		case BetaPolicyScopeAll, BetaPolicyScopeOAuth, BetaPolicyScopeAPIKey, BetaPolicyScopeBedrock:
+			return true
+		}
+		return false
+	}
+
+	for idx := range settings.Rules {
+		current := &settings.Rules[idx]
+
+		normalizedTier := strings.ToLower(strings.TrimSpace(current.ServiceTier))
+		if normalizedTier == "" {
+			normalizedTier = OpenAIFastTierAny
+		}
+		if !isKnownTier(normalizedTier) {
+			// Report the value exactly as the caller supplied it.
+			return fmt.Errorf("rule[%d]: invalid service_tier %q", idx, current.ServiceTier)
+		}
+		current.ServiceTier = normalizedTier
+
+		if !isKnownAction(current.Action) {
+			return fmt.Errorf("rule[%d]: invalid action %q", idx, current.Action)
+		}
+		if !isKnownScope(current.Scope) {
+			return fmt.Errorf("rule[%d]: invalid scope %q", idx, current.Scope)
+		}
+
+		for pos := range current.ModelWhitelist {
+			cleaned := strings.TrimSpace(current.ModelWhitelist[pos])
+			if cleaned == "" {
+				return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", idx, pos)
+			}
+			current.ModelWhitelist[pos] = cleaned
+		}
+
+		// An empty fallback action is allowed; a non-empty one must be a known action.
+		if current.FallbackAction != "" && !isKnownAction(current.FallbackAction) {
+			return fmt.Errorf("rule[%d]: invalid fallback_action %q", idx, current.FallbackAction)
+		}
+	}
+
+	encoded, err := json.Marshal(settings)
+	if err != nil {
+		return fmt.Errorf("marshal openai fast policy settings: %w", err)
+	}
+
+	return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(encoded))
+}
+
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 70d8efc3..41c01cca 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -104,12 +104,15 @@ type SystemSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
- DefaultConcurrency int
- DefaultBalance float64
- AffiliateEnabled bool
- AffiliateRebateRate float64
- DefaultUserRPMLimit int
- DefaultSubscriptions []DefaultSubscriptionSetting
+ DefaultConcurrency int
+ DefaultBalance float64
+ AffiliateEnabled bool
+ AffiliateRebateRate float64
+ AffiliateRebateFreezeHours int
+ AffiliateRebateDurationDays int
+ AffiliateRebatePerInviteeCap float64
+ DefaultUserRPMLimit int
+ DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -146,9 +149,10 @@ type SystemSettings struct {
BackendModeEnabled bool
// Gateway forwarding behavior
- EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
- EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
- EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+ EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
+ EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
+ EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+ EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
@@ -402,3 +406,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
+
+// OpenAI Fast Policy constants.
+// OpenAI's "fast mode" is identified by the service_tier field of the request body:
+// - "priority" (clients may send "fast", which is normalized to "priority"): fast mode
+// - "flex": low-priority mode
+// - omitted: the normal default
+//
+// This policy reuses the semantics of the BetaPolicyAction*/BetaPolicyScope*
+// constants; only the match key changes, from the anthropic-beta header to the
+// service_tier field of the body.
+const (
+ OpenAIFastTierAny = "all" // matches any recognized service_tier
+ OpenAIFastTierPriority = "priority" // matches fast (priority) only
+ OpenAIFastTierFlex = "flex" // matches flex only
+)
+
+// OpenAIFastPolicyRule is a single OpenAI fast/flex policy rule.
+type OpenAIFastPolicyRule struct {
+ ServiceTier string `json:"service_tier"` // "all" | "priority" | "flex" — NOTE(review): "auto"/"default"/"scale" were listed here, but SetOpenAIFastPolicySettings rejects them
+ Action string `json:"action"` // "pass" | "filter" | "block"
+ Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
+ ErrorMessage string `json:"error_message,omitempty"` // custom error message (takes effect when action=block)
+ ModelWhitelist []string `json:"model_whitelist,omitempty"` // model match patterns (empty = applies to all models)
+ FallbackAction string `json:"fallback_action,omitempty"` // how to handle models that do not match the whitelist
+ FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // custom error message for non-whitelisted models (takes effect when fallback_action=block)
+}
+
+// OpenAIFastPolicySettings is the OpenAI fast policy configuration: a list of
+// rules that is stored as JSON under the "rules" key.
+type OpenAIFastPolicySettings struct {
+ Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
+// DefaultOpenAIFastPolicySettings returns the default OpenAI fast policy.
+// By default priority (fast) requests are filtered for every model, i.e. the
+// service_tier field is stripped so the upstream handles the request at normal
+// priority.
+//
+// Why ModelWhitelist is empty (= applies to all models):
+// the codex client's service_tier=fast is a user-level switch that is
+// orthogonal to the model field. Even with gpt-4 + fast the priority quota is
+// still consumed. If the default rule only targeted gpt-5.5*, the path
+// "gpt-4 + fast passed through to the priority upstream" would bypass the
+// policy. To match codex's real semantics the default covers all models;
+// administrators who want to target specific models can configure
+// model_whitelist explicitly in the admin UI.
+func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
+	filterPriority := OpenAIFastPolicyRule{
+		ServiceTier:    OpenAIFastTierPriority,
+		Action:         BetaPolicyActionFilter,
+		Scope:          BetaPolicyScopeAll,
+		ModelWhitelist: []string{},
+		FallbackAction: BetaPolicyActionPass,
+	}
+	defaults := &OpenAIFastPolicySettings{
+		Rules: []OpenAIFastPolicyRule{filterPriority},
+	}
+	return defaults
+}
diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go
new file mode 100644
index 00000000..4430cf81
--- /dev/null
+++ b/backend/internal/service/vertex_service_account.go
@@ -0,0 +1,345 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+// Defaults and tuning values for Vertex AI service-account access.
+const (
+ // vertexDefaultLocation is used when neither the account nor the caller supplies a location.
+ vertexDefaultLocation = "us-central1"
+ // vertexDefaultTokenURL is the token endpoint forced onto every parsed key; it doubles as the JWT audience.
+ vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
+ // vertexCloudPlatformScope is the scope claim placed in the signed assertion.
+ vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
+ // vertexServiceAccountCacheSkew is subtracted from the token lifetime before caching (only when the lifetime exceeds it).
+ vertexServiceAccountCacheSkew = 5 * time.Minute
+ // vertexLockWaitTime is how long a caller that failed to get the refresh lock waits before re-reading the cache.
+ vertexLockWaitTime = 200 * time.Millisecond
+ // vertexAnthropicVersion is written into the anthropic_version field of Anthropic-on-Vertex request bodies.
+ vertexAnthropicVersion = "vertex-2023-10-16"
+)
+
+// Precompiled patterns used when building Vertex URLs and model IDs.
+var (
+ // vertexLocationPattern restricts locations to lowercase letters, digits and hyphens.
+ vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
+ // vertexAnthropicDatedModelIDPattern captures "<name>-<8 digits>" as (name, digits).
+ vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
+ // vertexAnthropicAlreadyDatedIDPattern matches IDs that already end in "@<8 digits>".
+ vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
+)
+
+// vertexServiceAccountKey holds the fields read from a service account key JSON document.
+// parseVertexServiceAccountJSON requires ClientEmail, PrivateKey and ProjectID
+// to be non-blank and always overwrites TokenURI with vertexDefaultTokenURL.
+type vertexServiceAccountKey struct {
+ Type string `json:"type"`
+ ProjectID string `json:"project_id"`
+ PrivateKeyID string `json:"private_key_id"` // sent as the JWT "kid" header when non-blank
+ PrivateKey string `json:"private_key"` // PEM-encoded RSA private key used to sign the assertion
+ ClientEmail string `json:"client_email"` // used as the JWT "iss" claim
+ TokenURI string `json:"token_uri"`
+}
+
+// vertexTokenResponse is the JSON body decoded from the token endpoint response.
+// Error and ErrorDesc are only consulted when the HTTP status is not 2xx.
+type vertexTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in"` // lifetime in seconds; values <= 0 are treated as one hour
+ Error string `json:"error"`
+ ErrorDesc string `json:"error_description"`
+}
+
+// IsVertexServiceAccount reports whether the account is of the service-account
+// type. A nil receiver yields false.
+func (a *Account) IsVertexServiceAccount() bool {
+	if a == nil {
+		return false
+	}
+	return a.Type == AccountTypeServiceAccount
+}
+
+// VertexProjectID returns the Vertex project ID for the account.
+// An explicit "project_id" credential wins; otherwise the project_id embedded
+// in the service account key is used. The result is "" for a nil receiver or
+// when the key cannot be parsed.
+func (a *Account) VertexProjectID() string {
+	if a == nil {
+		return ""
+	}
+	if explicit := strings.TrimSpace(a.GetCredential("project_id")); explicit != "" {
+		return explicit
+	}
+	parsedKey, err := parseVertexServiceAccountKey(a)
+	if err != nil {
+		return ""
+	}
+	return strings.TrimSpace(parsedKey.ProjectID)
+}
+
+// VertexLocation resolves the Vertex location to use for model.
+//
+// Lookup order: the per-model entry in the "vertex_model_locations" credential
+// map, then the "location" credential, then "vertex_location", and finally
+// vertexDefaultLocation. Blank values are skipped and results are trimmed.
+func (a *Account) VertexLocation(model string) string {
+	if a == nil {
+		return vertexDefaultLocation
+	}
+	if model != "" && a.Credentials != nil {
+		perModel, isMap := a.Credentials["vertex_model_locations"].(map[string]any)
+		if isMap {
+			if configured, isString := perModel[model].(string); isString {
+				if trimmed := strings.TrimSpace(configured); trimmed != "" {
+					return trimmed
+				}
+			}
+		}
+	}
+	for _, credential := range [...]string{"location", "vertex_location"} {
+		if trimmed := strings.TrimSpace(a.GetCredential(credential)); trimmed != "" {
+			return trimmed
+		}
+	}
+	return vertexDefaultLocation
+}
+
+// parseVertexServiceAccountKey extracts the service account key from the
+// account's credentials.
+//
+// The key may be stored under "service_account_json" or "service_account",
+// either as a JSON string or as a nested object. String forms are tried first
+// (in that key order), then nested objects (same order). An error is returned
+// when the account or its credentials are nil, when no key is present, or when
+// the key fails validation in parseVertexServiceAccountJSON.
+func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
+	if account == nil || account.Credentials == nil {
+		return nil, errors.New("service account credentials not configured")
+	}
+
+	credentialNames := [...]string{"service_account_json", "service_account"}
+
+	// String-encoded keys take precedence over nested objects.
+	for _, name := range credentialNames {
+		if raw := strings.TrimSpace(account.GetCredential(name)); raw != "" {
+			return parseVertexServiceAccountJSON([]byte(raw))
+		}
+	}
+	for _, name := range credentialNames {
+		nested, ok := account.Credentials[name].(map[string]any)
+		if !ok {
+			continue
+		}
+		encoded, err := json.Marshal(nested)
+		if err != nil {
+			// Previously ignored: an unencodable value surfaced later as a
+			// confusing "invalid service account json" error on empty input.
+			return nil, fmt.Errorf("encode nested %s credential: %w", name, err)
+		}
+		return parseVertexServiceAccountJSON(encoded)
+	}
+	return nil, errors.New("service_account_json not found in credentials")
+}
+
+// parseVertexServiceAccountJSON decodes and validates a service account key.
+// client_email, private_key and project_id must all be non-blank. The returned
+// key always has TokenURI set to vertexDefaultTokenURL, whatever the input says.
+func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
+	parsed := new(vertexServiceAccountKey)
+	if err := json.Unmarshal(raw, parsed); err != nil {
+		return nil, fmt.Errorf("invalid service account json: %w", err)
+	}
+	isBlank := func(v string) bool { return strings.TrimSpace(v) == "" }
+	switch {
+	case isBlank(parsed.ClientEmail):
+		return nil, errors.New("service account json missing client_email")
+	case isBlank(parsed.PrivateKey):
+		return nil, errors.New("service account json missing private_key")
+	case isBlank(parsed.ProjectID):
+		return nil, errors.New("service account json missing project_id")
+	}
+	// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
+	parsed.TokenURI = vertexDefaultTokenURL
+	return parsed, nil
+}
+
+// vertexServiceAccountCacheKey derives the token cache key for a service account.
+// With a key, the suffix is the first 8 bytes (hex) of
+// SHA-256(client_email NUL private_key_id). Without a key it falls back to
+// "account:<id>", and to the bare prefix when the account is nil as well.
+func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
+	const prefix = "vertex:service_account:"
+	switch {
+	case key != nil:
+		digest := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
+		return prefix + hex.EncodeToString(digest[:8])
+	case account != nil:
+		return prefix + fmt.Sprintf("account:%d", account.ID)
+	default:
+		return prefix
+	}
+}
+
+// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
+// using the shared cache and distributed lock to avoid redundant exchanges.
+//
+// Flow:
+//  1. Return the cached token when present.
+//  2. Try to take the refresh lock. The holder re-checks the cache (another
+//     holder may have just refreshed) before exchanging. A caller that loses the
+//     race waits vertexLockWaitTime, honoring ctx cancellation, and re-reads the
+//     cache. If the lock call itself fails, a warning is logged and the exchange
+//     proceeds without the lock.
+//  3. Exchange a signed assertion for a fresh token and cache it (best effort).
+//
+// A nil cache disables caching and locking. Errors come from key parsing, the
+// token exchange, or ctx being done while waiting for the lock holder.
+func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
+	key, err := parseVertexServiceAccountKey(account)
+	if err != nil {
+		return "", err
+	}
+	cacheKey := vertexServiceAccountCacheKey(account, key)
+
+	// cachedToken reports a usable (non-blank) cached token, if any.
+	cachedToken := func() (string, bool) {
+		if cache == nil {
+			return "", false
+		}
+		token, getErr := cache.GetAccessToken(ctx, cacheKey)
+		if getErr != nil || strings.TrimSpace(token) == "" {
+			return "", false
+		}
+		return token, true
+	}
+
+	if token, ok := cachedToken(); ok {
+		return token, nil
+	}
+
+	if cache != nil {
+		locked, lockErr := cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+		switch {
+		case lockErr != nil:
+			// Lock backend unavailable: fall through and exchange without it.
+			slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
+		case locked:
+			defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
+			// A previous lock holder may have stored a token between our first
+			// cache read and the lock acquisition; avoid a redundant exchange.
+			if token, ok := cachedToken(); ok {
+				return token, nil
+			}
+		default:
+			// Someone else is refreshing: give them a moment, but do not
+			// outlive the caller's context while waiting.
+			timer := time.NewTimer(vertexLockWaitTime)
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return "", ctx.Err()
+			case <-timer.C:
+			}
+			if token, ok := cachedToken(); ok {
+				return token, nil
+			}
+		}
+	}
+
+	accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
+	if err != nil {
+		return "", err
+	}
+	if cache != nil {
+		// Best effort: a cache write failure only costs a later re-exchange.
+		_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
+	}
+	return accessToken, nil
+}
+
+// exchangeVertexServiceAccountToken trades a self-signed RS256 JWT assertion for
+// an OAuth access token at key.TokenURI (JWT bearer grant).
+//
+// It returns the access token and the duration for which it may be cached: the
+// reported expires_in (one hour when missing or non-positive), reduced by
+// vertexServiceAccountCacheSkew when the lifetime is longer than the skew.
+//
+// Errors cover an unparsable private key, signing failures, transport
+// failures, non-2xx responses (including the server's error description when
+// available) and 2xx responses without an access token. In the last case a
+// body read or decode failure is reported as the cause instead of the generic
+// "missing access_token" message.
+func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
+	now := time.Now()
+	claims := jwt.MapClaims{
+		"iss":   key.ClientEmail,
+		"scope": vertexCloudPlatformScope,
+		"aud":   key.TokenURI,
+		"iat":   now.Unix(),
+		"exp":   now.Add(time.Hour).Unix(),
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+	if strings.TrimSpace(key.PrivateKeyID) != "" {
+		token.Header["kid"] = key.PrivateKeyID
+	}
+	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
+	if err != nil {
+		return "", 0, fmt.Errorf("parse service account private key: %w", err)
+	}
+	assertion, err := token.SignedString(privateKey)
+	if err != nil {
+		return "", 0, fmt.Errorf("sign service account assertion: %w", err)
+	}
+
+	values := url.Values{}
+	values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+	values.Set("assertion", assertion)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
+	if err != nil {
+		return "", 0, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	client := &http.Client{Timeout: 15 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", 0, fmt.Errorf("service account token request failed: %w", err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	// Cap the body at 1 MiB. Read and decode errors are kept rather than
+	// discarded so they can explain a missing token below.
+	body, readErr := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+	var parsed vertexTokenResponse
+	decodeErr := json.Unmarshal(body, &parsed)
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		// Error bodies are not guaranteed to be JSON, so fall back from the
+		// structured fields to the raw body text.
+		msg := strings.TrimSpace(parsed.ErrorDesc)
+		if msg == "" {
+			msg = strings.TrimSpace(parsed.Error)
+		}
+		if msg == "" {
+			msg = string(bytes.TrimSpace(body))
+		}
+		return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
+	}
+	if strings.TrimSpace(parsed.AccessToken) == "" {
+		switch {
+		case readErr != nil:
+			return "", 0, fmt.Errorf("read service account token response: %w", readErr)
+		case decodeErr != nil:
+			return "", 0, fmt.Errorf("decode service account token response: %w", decodeErr)
+		}
+		return "", 0, errors.New("service account token response missing access_token")
+	}
+	ttl := time.Duration(parsed.ExpiresIn) * time.Second
+	if ttl <= 0 {
+		ttl = time.Hour
+	}
+	// Expire the cache entry early so a token close to expiry is never handed
+	// out. Lifetimes at or below the skew are left unchanged.
+	if ttl > vertexServiceAccountCacheSkew {
+		ttl -= vertexServiceAccountCacheSkew
+	}
+	return parsed.AccessToken, ttl, nil
+}
+
+// buildVertexGeminiURL builds the Vertex AI endpoint URL for a Google publisher
+// (Gemini) model.
+//
+// All string arguments are trimmed. projectID and model are required; an empty
+// location falls back to vertexDefaultLocation and must match
+// vertexLocationPattern. action must be generateContent, streamGenerateContent
+// or countTokens. The "global" location uses the un-prefixed host. When stream
+// is true, "?alt=sse" is appended.
+func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
+	projectID, location = strings.TrimSpace(projectID), strings.TrimSpace(location)
+	model, action = strings.TrimSpace(model), strings.TrimSpace(action)
+
+	if projectID == "" {
+		return "", errors.New("vertex project_id is required")
+	}
+	if location == "" {
+		location = vertexDefaultLocation
+	}
+	if !vertexLocationPattern.MatchString(location) {
+		return "", fmt.Errorf("invalid vertex location: %s", location)
+	}
+	if model == "" {
+		return "", errors.New("vertex model is required")
+	}
+	supported := action == "generateContent" || action == "streamGenerateContent" || action == "countTokens"
+	if !supported {
+		return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
+	}
+
+	host := "aiplatform.googleapis.com"
+	if location != "global" {
+		host = location + "-" + host
+	}
+	endpoint := "https://" + host +
+		"/v1/projects/" + url.PathEscape(projectID) +
+		"/locations/" + url.PathEscape(location) +
+		"/publishers/google/models/" + url.PathEscape(model) +
+		":" + action
+	if stream {
+		endpoint += "?alt=sse"
+	}
+	return endpoint, nil
+}
+
+// buildVertexAnthropicURL builds the Vertex AI endpoint URL for an Anthropic
+// publisher model.
+//
+// All string arguments are trimmed. projectID and model are required; an empty
+// location falls back to vertexDefaultLocation and must match
+// vertexLocationPattern. The action is streamRawPredict when stream is true and
+// rawPredict otherwise. The "global" location uses the un-prefixed host. The
+// model is path-escaped except that "@" is kept literal.
+func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
+	projectID, location, model = strings.TrimSpace(projectID), strings.TrimSpace(location), strings.TrimSpace(model)
+
+	if projectID == "" {
+		return "", errors.New("vertex project_id is required")
+	}
+	if location == "" {
+		location = vertexDefaultLocation
+	}
+	if !vertexLocationPattern.MatchString(location) {
+		return "", fmt.Errorf("invalid vertex location: %s", location)
+	}
+	if model == "" {
+		return "", errors.New("vertex model is required")
+	}
+
+	var verb string
+	if stream {
+		verb = "streamRawPredict"
+	} else {
+		verb = "rawPredict"
+	}
+	host := "aiplatform.googleapis.com"
+	if location != "global" {
+		host = location + "-" + host
+	}
+	// Undo the escaping of "@" so dated IDs such as name@20250929 stay readable.
+	modelSegment := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
+	endpoint := "https://" + host +
+		"/v1/projects/" + url.PathEscape(projectID) +
+		"/locations/" + url.PathEscape(location) +
+		"/publishers/anthropic/models/" + modelSegment +
+		":" + verb
+	return endpoint, nil
+}
+
+// normalizeVertexAnthropicModelID rewrites a model ID ending in "-<8 digits>"
+// to the "<name>@<8 digits>" form. The input is trimmed; IDs that are empty,
+// already end in "@<8 digits>", or carry no 8-digit suffix are returned as is.
+func normalizeVertexAnthropicModelID(model string) string {
+	trimmed := strings.TrimSpace(model)
+	switch {
+	case trimmed == "", vertexAnthropicAlreadyDatedIDPattern.MatchString(trimmed):
+		return trimmed
+	}
+	parts := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(trimmed)
+	if len(parts) != 3 {
+		return trimmed
+	}
+	name, date := parts[1], parts[2]
+	return name + "@" + date
+}
+
+// buildVertexAnthropicRequestBody adapts an Anthropic Messages request body for
+// Vertex AI: the top-level "model" field is removed (the model is carried in
+// the URL) and "anthropic_version" is set to vertexAnthropicVersion.
+//
+// The body must be a JSON object. Invalid JSON, a non-object value, or a JSON
+// null yields an error. null is rejected explicitly because it decodes into a
+// nil map, and writing to a nil map would panic.
+//
+// NOTE: the body is round-tripped through map[string]any, so encoding/json
+// decodes numbers as float64 and re-encodes top-level keys in sorted order.
+func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
+	var payload map[string]any
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
+	}
+	if payload == nil {
+		return nil, errors.New("parse anthropic vertex request body: expected a JSON object, got null")
+	}
+	delete(payload, "model")
+	payload["anthropic_version"] = vertexAnthropicVersion
+	return json.Marshal(payload)
+}
diff --git a/backend/internal/service/vertex_service_account_test.go b/backend/internal/service/vertex_service_account_test.go
new file mode 100644
index 00000000..519f5b2f
--- /dev/null
+++ b/backend/internal/service/vertex_service_account_test.go
@@ -0,0 +1,77 @@
+package service
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// TestBuildVertexGeminiURL checks the regional streaming endpoint, including
+// the "?alt=sse" suffix.
+func TestBuildVertexGeminiURL(t *testing.T) {
+	const want = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse"
+
+	endpoint, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
+
+	require.NoError(t, err)
+	require.Equal(t, want, endpoint)
+}
+
+// TestBuildVertexGeminiURLUsesGlobalEndpointHost checks that the "global"
+// location maps to the host without a region prefix.
+func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
+	const want = "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse"
+
+	endpoint, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
+
+	require.NoError(t, err)
+	require.Equal(t, want, endpoint)
+}
+
+// TestBuildVertexAnthropicURL checks the regional rawPredict endpoint and that
+// the "@" in a dated model ID is kept literal.
+func TestBuildVertexAnthropicURL(t *testing.T) {
+	const want = "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"
+
+	endpoint, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
+
+	require.NoError(t, err)
+	require.Equal(t, want, endpoint)
+}
+
+// TestBuildVertexAnthropicURLUsesGlobalEndpointHost checks the "global" host
+// and the streamRawPredict action used for streaming requests.
+func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
+	const want = "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict"
+
+	endpoint, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
+
+	require.NoError(t, err)
+	require.Equal(t, want, endpoint)
+}
+
+// TestNormalizeVertexAnthropicModelID covers a dash-dated ID, an ID that is
+// already in "@" form, and an ID without a date suffix.
+func TestNormalizeVertexAnthropicModelID(t *testing.T) {
+	cases := []struct {
+		input string
+		want  string
+	}{
+		{input: "claude-sonnet-4-5-20250929", want: "claude-sonnet-4-5@20250929"},
+		{input: "claude-sonnet-4-5@20250929", want: "claude-sonnet-4-5@20250929"},
+		{input: "claude-sonnet-4-6", want: "claude-sonnet-4-6"},
+	}
+	for _, tc := range cases {
+		require.Equal(t, tc.want, normalizeVertexAnthropicModelID(tc.input))
+	}
+}
+
+// TestBuildVertexAnthropicRequestBody checks that "model" is dropped,
+// "anthropic_version" is replaced, and the other fields survive unchanged.
+func TestBuildVertexAnthropicRequestBody(t *testing.T) {
+	request := []byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
+
+	rewritten, err := buildVertexAnthropicRequestBody(request)
+	require.NoError(t, err)
+
+	field := func(path string) gjson.Result { return gjson.GetBytes(rewritten, path) }
+	require.Equal(t, "", field("model").String())
+	require.Equal(t, vertexAnthropicVersion, field("anthropic_version").String())
+	require.Equal(t, int64(64), field("max_tokens").Int())
+	require.Equal(t, "hi", field("messages.0.content").String())
+}
+
+// TestBuildVertexGeminiURLRejectsInvalidLocation checks that a location
+// containing a path separator is refused.
+func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
+	const badLocation = "us-central1/path"
+
+	_, buildErr := buildVertexGeminiURL("my-project", badLocation, "gemini-3-pro", "generateContent", false)
+
+	require.Error(t, buildErr)
+	require.Contains(t, buildErr.Error(), "invalid vertex location")
+}
+
+// TestParseVertexServiceAccountKey checks that a string-encoded key stored
+// under "service_account_json" is parsed and that the token URI is forced to
+// the default endpoint.
+func TestParseVertexServiceAccountKey(t *testing.T) {
+	serviceAccountJSON := `{
+ "type": "service_account",
+ "project_id": "vertex-proj",
+ "private_key_id": "kid",
+ "private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
+ "client_email": "svc@vertex-proj.iam.gserviceaccount.com"
+ }`
+	credentials := map[string]any{
+		"service_account_json": serviceAccountJSON,
+	}
+	acct := &Account{
+		Type:        AccountTypeServiceAccount,
+		Platform:    PlatformGemini,
+		Credentials: credentials,
+	}
+
+	parsed, err := parseVertexServiceAccountKey(acct)
+
+	require.NoError(t, err)
+	require.Equal(t, "vertex-proj", parsed.ProjectID)
+	require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", parsed.ClientEmail)
+	require.Equal(t, vertexDefaultTokenURL, parsed.TokenURI)
+	require.True(t, strings.Contains(parsed.PrivateKey, "BEGIN PRIVATE KEY"))
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index d1a50913..1b7cb7ac 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -407,12 +407,28 @@ func ProvideBillingCacheService(
return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
}
+// ProvideAPIKeyService is the Wire provider for APIKeyService. It constructs the
+// service with NewAPIKeyService and then registers billingCacheService as the
+// service's rate-limit cache invalidator.
+//
+// A nil billingCacheService is not registered.
+// NOTE(review): this assumes SetRateLimitCacheInvalidator takes an interface
+// parameter - confirm. If so, passing a nil *BillingCacheService would store a
+// non-nil interface wrapping a nil pointer, and a later "invalidator != nil"
+// check would no longer protect against calling through it.
+func ProvideAPIKeyService(
+	apiKeyRepo APIKeyRepository,
+	userRepo UserRepository,
+	groupRepo GroupRepository,
+	userSubRepo UserSubscriptionRepository,
+	userGroupRateRepo UserGroupRateRepository,
+	cache APIKeyCache,
+	cfg *config.Config,
+	billingCacheService *BillingCacheService,
+) *APIKeyService {
+	svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
+	if billingCacheService != nil {
+		svc.SetRateLimitCacheInvalidator(billingCacheService)
+	}
+	return svc
+}
+
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
- NewAPIKeyService,
+ ProvideAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
diff --git a/backend/migrations/133_affiliate_rebate_freeze.sql b/backend/migrations/133_affiliate_rebate_freeze.sql
new file mode 100644
index 00000000..b87d59b7
--- /dev/null
+++ b/backend/migrations/133_affiliate_rebate_freeze.sql
@@ -0,0 +1,17 @@
+-- Migration 133: affiliate rebate freeze.
+-- The ALTER TABLE and CREATE INDEX statements below use IF NOT EXISTS, so they
+-- do not fail when the column or index is already present.
+--
+-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
+--    The column is NOT NULL with a default of 0.
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
+
+-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
+-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
+ALTER TABLE user_affiliate_ledger
+ ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
+
+COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
+
+-- 3) Partial index on (user_id, frozen_until) for efficient thaw queries.
+--    Only rows with a non-NULL frozen_until (still frozen) are indexed.
+CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
+ ON user_affiliate_ledger (user_id, frozen_until)
+ WHERE frozen_until IS NOT NULL;
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
index a484d7ed..07a68c03 100644
--- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts affiliate code when completing linuxdo oauth registration', async () => {
+   const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+   const adoption = { adoptDisplayName: true, adoptAvatar: false }
+
+   // The affiliate code is passed with surrounding spaces; the posted aff_code
+   // is expected without them.
+   await completeLinuxDoOAuthRegistration('invite-code', adoption, ' AFF123 ')
+
+   const expectedBody = {
+     invitation_code: 'invite-code',
+     aff_code: 'AFF123',
+     adopt_display_name: true,
+     adopt_avatar: false
+   }
+   expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', expectedBody)
+ })
+
it('posts oidc invitation completion with adoption decisions', async () => {
const { completeOIDCOAuthRegistration } = await import('@/api/auth')
@@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => {
})
})
+ it('posts affiliate code when creating pending wechat oauth account', async () => {
+   const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+   const adoption = { adoptDisplayName: false, adoptAvatar: true }
+
+   await createPendingWeChatOAuthAccount('invite-code', adoption, 'WXAFF')
+
+   // The affiliate code travels as aff_code next to the invitation code and the
+   // serialized adoption flags.
+   const expectedBody = {
+     invitation_code: 'invite-code',
+     aff_code: 'WXAFF',
+     adopt_display_name: false,
+     adopt_avatar: true
+   }
+   expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', expectedBody)
+ })
+
it('classifies oauth completion results as login or bind', async () => {
const { getOAuthCompletionKind } = await import('@/api/auth')
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index a146f1f7..8a127793 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: {
* @returns Success confirmation
*/
export async function bulkUpdate(
- accountIds: number[],
- updates: Record
+ accountIdsOrPayload: number[] | Record,
+ updates?: Record
): Promise<{
success: number
failed: number
@@ -379,16 +379,19 @@ export async function bulkUpdate(
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
}> {
+ const payload = Array.isArray(accountIdsOrPayload)
+ ? {
+ account_ids: accountIdsOrPayload,
+ ...(updates ?? {})
+ }
+ : accountIdsOrPayload
const { data } = await apiClient.post<{
success: number
failed: number
success_ids?: number[]
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
- }>('/admin/accounts/bulk-update', {
- account_ids: accountIds,
- ...updates
- })
+ }>('/admin/accounts/bulk-update', payload)
return data
}
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index cf8626fc..b887355a 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -309,6 +309,9 @@ export interface SystemSettings {
// Default settings
default_balance: number;
affiliate_rebate_rate: number;
+ affiliate_rebate_freeze_hours: number;
+ affiliate_rebate_duration_days: number;
+ affiliate_rebate_per_invitee_cap: number;
default_concurrency: number;
default_user_rpm_limit: number;
default_subscriptions: DefaultSubscriptionSetting[];
@@ -437,6 +440,7 @@ export interface SystemSettings {
enable_fingerprint_unification: boolean;
enable_metadata_passthrough: boolean;
enable_cch_signing: boolean;
+ enable_anthropic_cache_ttl_1h_injection: boolean;
web_search_emulation_enabled?: boolean;
// Payment configuration
@@ -482,6 +486,9 @@ export interface SystemSettings {
// Affiliate (邀请返利) feature switch
affiliate_enabled: boolean;
+
+ // OpenAI fast/flex policy
+ openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
export interface UpdateSettingsRequest {
@@ -495,6 +502,9 @@ export interface UpdateSettingsRequest {
totp_enabled?: boolean; // TOTP 双因素认证
default_balance?: number;
affiliate_rebate_rate?: number;
+ affiliate_rebate_freeze_hours?: number;
+ affiliate_rebate_duration_days?: number;
+ affiliate_rebate_per_invitee_cap?: number;
default_concurrency?: number;
default_user_rpm_limit?: number;
default_subscriptions?: DefaultSubscriptionSetting[];
@@ -602,6 +612,7 @@ export interface UpdateSettingsRequest {
enable_fingerprint_unification?: boolean;
enable_metadata_passthrough?: boolean;
enable_cch_signing?: boolean;
+ enable_anthropic_cache_ttl_1h_injection?: boolean;
// Payment configuration
payment_enabled?: boolean;
payment_min_amount?: number;
@@ -644,6 +655,9 @@ export interface UpdateSettingsRequest {
// Affiliate (邀请返利) feature switch
affiliate_enabled?: boolean;
+
+ // OpenAI fast/flex policy
+ openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
/**
@@ -871,6 +885,29 @@ export async function updateRectifierSettings(
return data;
}
+// ==================== OpenAI Fast Policy Settings ====================
+
+/**
+ * A single rule of the OpenAI fast/flex policy.
+ * Mirrors backend dto.OpenAIFastPolicyRule.
+ */
+export interface OpenAIFastPolicyRule {
+  service_tier: "all" | "priority" | "flex";
+  action: "pass" | "filter" | "block";
+  scope: "all" | "oauth" | "apikey" | "bedrock";
+  error_message?: string;
+  model_whitelist?: string[];
+  // Accepts exactly the same values as `action`.
+  fallback_action?: OpenAIFastPolicyRule["action"];
+  fallback_error_message?: string;
+}
+
+/**
+ * OpenAI fast/flex policy settings interface.
+ * Wraps the list of policy rules; used as the type of the optional
+ * `openai_fast_policy_settings` field on SystemSettings and UpdateSettingsRequest.
+ */
+export interface OpenAIFastPolicySettings {
+ rules: OpenAIFastPolicyRule[];
+}
+
// ==================== Beta Policy Settings ====================
/**
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index f49f3a1f..bb990fc4 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -564,9 +564,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
- return createPendingLinuxDoOAuthAccount(invitationCode, decision)
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode)
}
/**
@@ -576,27 +577,32 @@ export async function completeLinuxDoOAuthRegistration(
*/
export async function completeOIDCOAuthRegistration(
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
- return createPendingOIDCOAuthAccount(invitationCode, decision)
+ return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode)
}
export async function completeWeChatOAuthRegistration(
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
- return createPendingWeChatOAuthAccount(invitationCode, decision)
+ return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode)
}
async function createPendingOAuthAccount(
provider: 'linuxdo' | 'oidc' | 'wechat',
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
+ const normalizedAffiliateCode = affiliateCode?.trim()
const { data } = await apiClient.post(
`/auth/oauth/${provider}/complete-registration`,
{
invitation_code: invitationCode,
+ ...(normalizedAffiliateCode ? { aff_code: normalizedAffiliateCode } : {}),
...serializeOAuthAdoptionDecision(decision)
}
)
@@ -605,23 +611,26 @@ async function createPendingOAuthAccount(
export async function createPendingLinuxDoOAuthAccount(
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
- return createPendingOAuthAccount('linuxdo', invitationCode, decision)
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode)
}
export async function createPendingOIDCOAuthAccount(
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
- return createPendingOAuthAccount('oidc', invitationCode, decision)
+ return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode)
}
export async function createPendingWeChatOAuthAccount(
invitationCode: string,
- decision?: OAuthAdoptionDecision
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
): Promise {
- return createPendingOAuthAccount('wechat', invitationCode, decision)
+ return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode)
}
export async function completePendingOAuthBindLogin(
diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue
index 41dd1505..90a67922 100644
--- a/frontend/src/components/account/AccountUsageCell.vue
+++ b/frontend/src/components/account/AccountUsageCell.vue
@@ -332,6 +332,37 @@
+
+
+
+ {{ formatKeyRequests }} req
+
+
+ {{ formatKeyTokens }}
+
+
+ A ${{ formatKeyCost }}
+
+
+ U ${{ formatKeyUserCost }}
+
+
+
+
@@ -545,6 +576,10 @@ const shouldFetchUsage = computed(() => {
return false
})
+// True only for Gemini accounts whose type is 'service_account'.
+const showGeminiTodayStats = computed(
+  () => props.account.platform === 'gemini' && props.account.type === 'service_account'
+)
+
const geminiUsageAvailable = computed(() => {
return (
!!usageInfo.value?.gemini_shared_daily ||
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 13c30cf9..05016a6d 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -17,7 +17,7 @@
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
- {{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
+ {{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}
@@ -27,7 +27,7 @@
- {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
+ {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}
@@ -227,7 +227,7 @@
@@ -698,6 +698,87 @@
+
+
+
+
+ {{ t('admin.accounts.openai.codexCLIOnly') }}
+
+
+
+
+
+ {{ t('admin.accounts.openai.codexCLIOnlyDesc') }}
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.openai.wsMode') }}
+
+
+
+
+
+ {{ t('admin.accounts.openai.wsModeDesc') }}
+
+
+ {{ t(openAIAPIKeyWSModeConcurrencyHintKey) }}
+
+
+
+
+
@@ -933,6 +1014,13 @@ interface Props {
accountIds: number[]
selectedPlatforms: AccountPlatform[]
selectedTypes: AccountType[]
+ target?: {
+ mode: 'selected' | 'filtered'
+ filters?: Record
+ previewCount?: number
+ selectedPlatforms?: AccountPlatform[]
+ selectedTypes?: AccountType[]
+ }
proxies: ProxyConfig[]
groups: AdminGroup[]
}
@@ -947,40 +1035,53 @@ const { t } = useI18n()
const appStore = useAppStore()
// Platform awareness
-const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
+const targetMode = computed(() => props.target?.mode ?? 'selected')
+const targetPreviewCount = computed(() => props.target?.previewCount ?? props.accountIds.length)
+const targetSelectedPlatforms = computed(() => props.target?.selectedPlatforms ?? props.selectedPlatforms)
+const targetSelectedTypes = computed(() => props.target?.selectedTypes ?? props.selectedTypes)
+const isMixedPlatform = computed(() => targetSelectedPlatforms.value.length > 1)
const allOpenAIPassthroughCapable = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'apikey')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'apikey')
)
})
const allOpenAIOAuth = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth')
+ )
+})
+
+// True when the target covers exactly one platform, that platform is OpenAI,
+// and the (non-empty) list of targeted account types contains only 'apikey'.
+const allOpenAIAPIKey = computed(() => {
+  if (targetSelectedPlatforms.value.length !== 1) return false
+  if (targetSelectedPlatforms.value[0] !== 'openai') return false
+
+  const accountTypes = targetSelectedTypes.value
+  if (accountTypes.length === 0) return false
+  return accountTypes.every(accountType => accountType === 'apikey')
+})
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
const allAnthropicOAuthOrSetupToken = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'anthropic' &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'setup-token')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'anthropic' &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'setup-token')
)
})
const filteredPresets = computed(() => {
- if (props.selectedPlatforms.length === 0) return []
+ if (targetSelectedPlatforms.value.length === 0) return []
const dedupedPresets = new Map[number]>()
- for (const platform of props.selectedPlatforms) {
+ for (const platform of targetSelectedPlatforms.value) {
for (const preset of getPresetMappingsByPlatform(platform)) {
const key = `${preset.from}=>${preset.to}`
if (!dedupedPresets.has(key)) {
@@ -1012,6 +1113,8 @@ const enableStatus = ref(false)
const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
+const enableOpenAIAPIKeyWSMode = ref(false)
+const enableCodexCLIOnly = ref(false)
const enableRpmLimit = ref(false)
// State - field values
@@ -1035,6 +1138,8 @@ const status = ref<'active' | 'inactive'>('active')
const groupIds = ref([])
const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
+const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
+const codexCLIOnlyEnabled = ref(false)
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@@ -1076,6 +1181,9 @@ const openAIWSModeOptions = computed(() => [
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
+// Hint key for the API-key WS mode selector, resolved from the currently chosen mode.
+const openAIAPIKeyWSModeConcurrencyHintKey = computed(() => {
+  const selectedMode = openaiAPIKeyResponsesWebSocketV2Mode.value
+  return resolveOpenAIWSModeConcurrencyHintKey(selectedMode)
+})
// Model mapping helpers
const addModelMapping = () => {
@@ -1254,6 +1362,19 @@ const buildUpdatePayload = (): Record | null => {
)
}
+ if (enableOpenAIAPIKeyWSMode.value) {
+ const extra = ensureExtra()
+ extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
+ extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
+ openaiAPIKeyResponsesWebSocketV2Mode.value
+ )
+ }
+
+ if (enableCodexCLIOnly.value) {
+ const extra = ensureExtra()
+ extra.codex_cli_only = codexCLIOnlyEnabled.value
+ }
+
// RPM limit settings (写入 extra 字段)
if (enableRpmLimit.value) {
const extra = ensureExtra()
@@ -1291,8 +1412,8 @@ const mixedChannelConfirmed = ref(false)
const canPreCheck = () =>
enableGroups.value &&
groupIds.value.length > 0 &&
- props.selectedPlatforms.length === 1 &&
- (props.selectedPlatforms[0] === 'antigravity' || props.selectedPlatforms[0] === 'anthropic')
+ targetSelectedPlatforms.value.length === 1 &&
+ (targetSelectedPlatforms.value[0] === 'antigravity' || targetSelectedPlatforms.value[0] === 'anthropic')
const handleClose = () => {
showMixedChannelWarning.value = false
@@ -1309,7 +1430,7 @@ const preCheckMixedChannelRisk = async (built: Record): Promise
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
- platform: props.selectedPlatforms[0],
+ platform: targetSelectedPlatforms.value[0],
group_ids: groupIds.value
})
if (!result.has_risk) return true
@@ -1325,7 +1446,7 @@ const preCheckMixedChannelRisk = async (built: Record): Promise
}
const handleSubmit = async () => {
- if (props.accountIds.length === 0) {
+ if (targetMode.value === 'selected' && props.accountIds.length === 0) {
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
return
}
@@ -1344,6 +1465,8 @@ const handleSubmit = async () => {
enableStatus.value ||
enableGroups.value ||
enableOpenAIWSMode.value ||
+ enableOpenAIAPIKeyWSMode.value ||
+ enableCodexCLIOnly.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
@@ -1373,7 +1496,12 @@ const submitBulkUpdate = async (baseUpdates: Record) => {
submitting.value = true
try {
- const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
+ const res = targetMode.value === 'filtered' && props.target?.filters
+ ? await adminAPI.accounts.bulkUpdate({
+ filters: props.target.filters,
+ ...updates
+ })
+ : await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
const success = res.success || 0
const failed = res.failed || 0
@@ -1437,6 +1565,8 @@ watch(
enableGroups.value = false
enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
+ enableOpenAIAPIKeyWSMode.value = false
+ enableCodexCLIOnly.value = false
enableRpmLimit.value = false
// Reset all values
@@ -1456,6 +1586,8 @@ watch(
status.value = 'active'
groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
+ openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
+ codexCLIOnlyEnabled.value = false
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 60686225..ca4cf3a7 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -166,7 +166,7 @@
{{ t('admin.accounts.accountType') }}
-
+
+
+
+
+
+
+ Vertex
+ Service Account
+
+
+
+
+
+
+
{{ t('admin.accounts.vertexAnthropicHint') }}
@@ -315,6 +348,7 @@
{{ t('admin.accounts.types.responsesApi') }}
+
@@ -333,7 +367,7 @@
{{ t('admin.accounts.gemini.helpButton') }}
-
+
+
+
+
+
+
+
+
+ Vertex
+
+
+ Service Account
+
+
+
+
+
{{ t('admin.accounts.vertexGeminiHint') }}
+
+
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
@@ -623,7 +694,7 @@
-
+
{{ t('admin.accounts.gemini.tier.label') }}
+
+
+
+
Service Account JSON
+
+
+
+
+
+
+ {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}
+
+
+ {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}
+
+
+
+
+ {{ t('admin.accounts.vertexSaJsonSelectBtn') }}
+
+
+
+
Project ID: {{ vertexProjectId }}
+
Client Email: {{ vertexClientEmail }}
+
+
+
{{ t('admin.accounts.vertexSaJsonUploadHint') }}
+
+
+
+
+ Project ID
+
+
+
+
Location
+
+
+
+ {{ option.label }}
+
+
+
+
{{ t('admin.accounts.vertexLocationHint') }}
+
+
+
+
@@ -3119,6 +3280,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -3233,7 +3395,7 @@ interface TempUnschedRuleForm {
// State
const step = ref(1)
const submitting = ref(false)
-const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
+const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category
const addMethod = ref
('oauth') // For oauth-based: 'oauth' or 'setup-token'
const apiKeyBaseUrl = ref('https://api.anthropic.com')
const apiKeyValue = ref('')
@@ -3306,6 +3468,12 @@ const bedrockSessionToken = ref('')
const bedrockRegion = ref('us-east-1')
const bedrockForceGlobal = ref(false)
const bedrockApiKeyValue = ref('')
+const vertexServiceAccountFileInput = ref(null)
+const vertexServiceAccountJson = ref('')
+const vertexProjectId = ref('')
+const vertexClientEmail = ref('')
+const vertexLocation = ref('global')
+const vertexServiceAccountDragActive = ref(false)
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping')
@@ -3556,7 +3724,7 @@ watch(
// Sync form.type based on accountCategory, addMethod, and platform-specific type
watch(
- [accountCategory, addMethod, antigravityAccountType],
+ [accountCategory, addMethod, antigravityAccountType, () => form.platform],
([category, method, agType]) => {
// Antigravity upstream 类型(实际创建为 apikey)
if (form.platform === 'antigravity' && agType === 'upstream') {
@@ -3568,7 +3736,9 @@ watch(
form.type = 'bedrock' as AccountType
return
}
- if (category === 'oauth-based') {
+ if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') {
+ form.type = 'service_account' as AccountType
+ } else if (category === 'oauth-based') {
form.type = method as AccountType // 'oauth' or 'setup-token'
} else {
form.type = 'apikey'
@@ -3606,6 +3776,12 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
+ if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
+ accountCategory.value = 'oauth-based'
+ }
+ if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') {
+ accountCategory.value = 'oauth-based'
+ }
// Reset Bedrock fields when switching platforms
bedrockAccessKeyId.value = ''
bedrockSecretAccessKey.value = ''
@@ -3614,6 +3790,10 @@ watch(
bedrockForceGlobal.value = false
bedrockAuthMode.value = 'sigv4'
bedrockApiKeyValue.value = ''
+ vertexServiceAccountJson.value = ''
+ vertexProjectId.value = ''
+ vertexClientEmail.value = ''
+ vertexLocation.value = 'global'
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
@@ -4068,6 +4248,10 @@ const resetForm = () => {
upstreamType.value = 'sub2api'
upstreamBaseUrl.value = ''
upstreamApiKey.value = ''
+ vertexServiceAccountJson.value = ''
+ vertexProjectId.value = ''
+ vertexClientEmail.value = ''
+ vertexLocation.value = 'global'
tempUnschedEnabled.value = false
tempUnschedRules.value = []
geminiOAuthType.value = 'code_assist'
@@ -4195,6 +4379,52 @@ const normalizePoolModeRetryCount = (value: number) => {
return normalized
}
+/**
+ * Validates a Vertex service-account JSON document and, on success, stores it.
+ *
+ * On success the re-serialized document goes into vertexServiceAccountJson and
+ * its trimmed project_id / client_email into vertexProjectId / vertexClientEmail.
+ * Blank input clears all three fields. Unparseable JSON or a document missing
+ * project_id, client_email or private_key leaves the stored state untouched.
+ * Every failure path reports an error through appStore.showError.
+ *
+ * @param value Raw JSON text (from a file, a drop, or the stored value).
+ * @returns true when the document was accepted, false otherwise.
+ */
+const applyVertexServiceAccountJson = (value: string) => {
+  const raw = value.trim()
+  if (!raw) {
+    // Clear the stored document together with the displayed fields so the form
+    // cannot show "nothing loaded" while still holding an earlier key.
+    vertexServiceAccountJson.value = ''
+    vertexProjectId.value = ''
+    vertexClientEmail.value = ''
+    // Report the problem instead of returning false silently; otherwise a
+    // submit without a key appears to do nothing.
+    appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
+    return false
+  }
+  try {
+    const parsed = JSON.parse(raw) as Record<string, unknown>
+    const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
+    const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
+    const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
+    if (!projectId || !clientEmail || !privateKey) {
+      appStore.showError(t('admin.accounts.vertexSaJsonMissingFields'))
+      return false
+    }
+    vertexProjectId.value = projectId
+    vertexClientEmail.value = clientEmail
+    // Store the re-serialized form of the parsed document rather than the raw text.
+    vertexServiceAccountJson.value = JSON.stringify(parsed)
+    return true
+  } catch {
+    appStore.showError(t('admin.accounts.vertexSaJsonInvalid'))
+    return false
+  }
+}
+
+// Re-validates the currently stored service-account JSON; returns the validation result.
+const parseVertexServiceAccountJson = () => {
+  const stored = vertexServiceAccountJson.value
+  return applyVertexServiceAccountJson(stored)
+}
+
+// Change handler for the service-account file input: reads the first selected
+// file as text and hands it to applyVertexServiceAccountJson.
+const handleVertexServiceAccountFile = async (event: Event) => {
+  const input = event.target as HTMLInputElement
+  const picked = input.files?.[0]
+  if (picked) {
+    try {
+      const contents = await picked.text()
+      applyVertexServiceAccountJson(contents)
+    } finally {
+      // Reset the input value after handling, whether or not reading succeeded.
+      input.value = ''
+    }
+  }
+}
+
+// Drop handler for the service-account drop zone: clears the drag-active flag,
+// then applies the first dropped file's text if a file was dropped.
+const handleVertexServiceAccountDrop = async (event: DragEvent) => {
+  vertexServiceAccountDragActive.value = false
+  const dropped = event.dataTransfer?.files?.[0]
+  if (dropped) {
+    const contents = await dropped.text()
+    applyVertexServiceAccountJson(contents)
+  }
+}
+
const handleSubmit = async () => {
// For OAuth-based type, handle OAuth flow (goes to step 2)
if (isOAuthFlow.value) {
@@ -4348,6 +4578,29 @@ const handleSubmit = async () => {
return
}
+ if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') {
+ if (!form.name.trim()) {
+ appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
+ return
+ }
+ if (!parseVertexServiceAccountJson()) {
+ return
+ }
+ if (!vertexLocation.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexLocationRequired'))
+ return
+ }
+ const credentials: Record = {
+ service_account_json: vertexServiceAccountJson.value.trim(),
+ project_id: vertexProjectId.value.trim(),
+ client_email: vertexClientEmail.value.trim(),
+ location: vertexLocation.value.trim(),
+ tier_id: 'vertex'
+ }
+ await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials)
+ return
+ }
+
// For apikey type, create directly
if (!apiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 85f05395..b117bff3 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -599,6 +599,221 @@
+
+
+
+
+
Project ID
+
+
{{ t('admin.accounts.vertexSaJsonEditHint') }}
+
+
+
Location
+
+
+
+ {{ option.label }}
+
+
+
+
{{ t('admin.accounts.vertexLocationHint') }}
+
+
+
+
+
+
{{ t('admin.accounts.modelRestriction') }}
+
+
+
+
+
+
+
+ {{ t('admin.accounts.modelWhitelist') }}
+
+
+
+
+
+ {{ t('admin.accounts.modelMapping') }}
+
+
+
+
+
+
+
+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
+ {{
+ t('admin.accounts.supportsAllModels')
+ }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.mapRequestModels') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.addMapping') }}
+
+
+
+
+
+ + {{ preset.label }}
+
+
+
+
+
+
@@ -1951,6 +2166,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -2020,6 +2236,9 @@ const editBedrockSessionToken = ref('')
const editBedrockRegion = ref('')
const editBedrockForceGlobal = ref(false)
const editBedrockApiKeyValue = ref('')
+const editVertexProjectId = ref('')
+const editVertexClientEmail = ref('')
+const editVertexLocation = ref('us-central1')
const isBedrockAPIKeyMode = computed(() =>
props.account?.type === 'bedrock' &&
(props.account?.credentials as Record
)?.auth_mode === 'apikey'
@@ -2279,6 +2498,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
const credentials = newAccount.credentials as Record | undefined
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
+ editVertexProjectId.value = ''
+ editVertexClientEmail.value = ''
+ editVertexLocation.value = 'us-central1'
// Load mixed scheduling setting (only for antigravity accounts)
mixedScheduling.value = false
@@ -2504,6 +2726,31 @@ const syncFormFromAccount = (newAccount: Account | null) => {
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record
editBaseUrl.value = (credentials.base_url as string) || ''
+ } else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) {
+ const credentials = newAccount.credentials as Record
+ editVertexProjectId.value = (credentials.project_id as string) || ''
+ editVertexClientEmail.value = (credentials.client_email as string) || ''
+ editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
+
+ // Load model mappings for service_account
+ const existingMappings = credentials.model_mapping as Record | undefined
+ if (existingMappings && typeof existingMappings === 'object') {
+ const entries = Object.entries(existingMappings)
+ const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
+ if (isWhitelistMode) {
+ modelRestrictionMode.value = 'whitelist'
+ allowedModels.value = entries.map(([from]) => from)
+ modelMappings.value = []
+ } else {
+ modelRestrictionMode.value = 'mapping'
+ modelMappings.value = entries.map(([from, to]) => ({ from, to }))
+ allowedModels.value = []
+ }
+ } else {
+ modelRestrictionMode.value = 'whitelist'
+ modelMappings.value = []
+ allowedModels.value = []
+ }
} else {
const platformDefaultUrl =
newAccount.platform === 'openai'
@@ -3099,6 +3346,46 @@ const handleSubmit = async () => {
return
}
+ updatePayload.credentials = newCredentials
+ } else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') {
+ const currentCredentials = (props.account.credentials as Record) || {}
+ const newCredentials: Record = { ...currentCredentials }
+
+ if (!editVertexProjectId.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId'))
+ return
+ }
+ if (!editVertexClientEmail.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail'))
+ return
+ }
+ if (!editVertexLocation.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexLocationRequired'))
+ return
+ }
+
+ if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
+ appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
+ return
+ }
+ newCredentials.project_id = editVertexProjectId.value.trim()
+ newCredentials.client_email = editVertexClientEmail.value.trim()
+ newCredentials.location = editVertexLocation.value.trim()
+ newCredentials.tier_id = 'vertex'
+
+ // Add model mapping if configured
+ const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ if (modelMapping) {
+ newCredentials.model_mapping = modelMapping
+ } else {
+ delete newCredentials.model_mapping
+ }
+
+ applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
+ if (!applyTempUnschedConfig(newCredentials)) {
+ return
+ }
+
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock') {
const currentCredentials = (props.account.credentials as Record) || {}
diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
index 9158da64..fa4104f6 100644
--- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
+++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
@@ -57,6 +57,19 @@ function makeAccount(overrides: Partial): Account {
describe('AccountUsageCell', () => {
beforeEach(() => {
getUsage.mockReset()
+ Object.defineProperty(window, 'matchMedia', {
+ writable: true,
+ value: vi.fn().mockImplementation(() => ({
+ matches: true,
+ media: '(min-width: 768px)',
+ onchange: null,
+ addListener: vi.fn(),
+ removeListener: vi.fn(),
+ addEventListener: vi.fn(),
+ removeEventListener: vi.fn(),
+ dispatchEvent: vi.fn(),
+ }))
+ })
})
it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
@@ -603,4 +616,43 @@ describe('AccountUsageCell', () => {
expect(wrapper.text().trim()).toBe('-')
})
+
+ it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => {
+ const wrapper = mount(AccountUsageCell, {
+ props: {
+ account: makeAccount({
+ id: 4001,
+ platform: 'gemini',
+ type: 'service_account',
+ credentials: {
+ tier_id: 'vertex',
+ project_id: 'vertex-proj',
+ client_email: 'svc@vertex-proj.iam.gserviceaccount.com',
+ location: 'global'
+ },
+ extra: {}
+ }),
+ todayStats: {
+ requests: 0,
+ tokens: 0,
+ cost: 0,
+ standard_cost: 0,
+ user_cost: 0
+ }
+ },
+ global: {
+ stubs: {
+ UsageProgressBar: true,
+ AccountQuotaInfo: true
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('0 req')
+ expect(wrapper.text()).toContain('0')
+ expect(wrapper.text()).toContain('A $0.00')
+ expect(wrapper.text()).toContain('U $0.00')
+ })
})
diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
index 7390e723..50d170da 100644
--- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
+++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
@@ -178,6 +178,45 @@ describe('BulkEditAccountModal', () => {
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
+ it('OpenAI OAuth 批量编辑应提交 codex_cli_only 字段', async () => {
+ const wrapper = mountModal({
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['oauth']
+ })
+
+ await wrapper.get('#bulk-edit-openai-codex-cli-only-enabled').setValue(true)
+ await wrapper.get('#bulk-edit-openai-codex-cli-only-toggle').trigger('click')
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+ extra: {
+ codex_cli_only: true
+ }
+ })
+ })
+
+ it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
+ const wrapper = mountModal({
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['apikey']
+ })
+
+ await wrapper.get('#bulk-edit-openai-apikey-ws-mode-enabled').setValue(true)
+ await wrapper.get('[data-testid="bulk-edit-openai-apikey-ws-mode-select"]').setValue('ctx_pool')
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+ extra: {
+ openai_apikey_responses_websockets_v2_mode: 'ctx_pool',
+ openai_apikey_responses_websockets_v2_enabled: true
+ }
+ })
+ })
+
it('OpenAI 账号批量编辑可关闭自动透传', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
@@ -217,4 +256,41 @@ describe('BulkEditAccountModal', () => {
})
expect(wrapper.text()).toContain('admin.accounts.openai.modelRestrictionDisabledByPassthrough')
})
+
+ it('filtered-results 模式下应提交 filters 而不是 account_ids', async () => {
+ const wrapper = mountModal({
+ accountIds: [],
+ target: {
+ mode: 'filtered',
+ filters: {
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ group: '12',
+ search: 'bulk-target',
+ privacy_mode: 'training_set_cf_blocked'
+ },
+ previewCount: 5,
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['oauth']
+ }
+ })
+
+ await wrapper.get('#bulk-edit-status-enabled').setValue(true)
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
+ filters: {
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ group: '12',
+ search: 'bulk-target',
+ privacy_mode: 'training_set_cf_blocked'
+ },
+ status: 'active'
+ })
+ })
})
diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue
index 3b987bd0..a632bdd4 100644
--- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue
+++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue
@@ -1,9 +1,13 @@
-
+
-
+
{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
+
+ {{ t('admin.accounts.bulkEdit.title') }}
+
+
{{ t('admin.accounts.bulkActions.clear') }}
+
- {{ t('admin.accounts.bulkActions.delete') }}
- {{ t('admin.accounts.bulkActions.resetStatus') }}
- {{ t('admin.accounts.bulkActions.refreshToken') }}
- {{ t('admin.accounts.bulkActions.enableScheduling') }}
- {{ t('admin.accounts.bulkActions.disableScheduling') }}
- {{ t('admin.accounts.bulkActions.edit') }}
+
+ {{ t('admin.accounts.bulkActions.delete') }}
+ {{ t('admin.accounts.bulkActions.resetStatus') }}
+ {{ t('admin.accounts.bulkActions.refreshToken') }}
+ {{ t('admin.accounts.bulkActions.enableScheduling') }}
+ {{ t('admin.accounts.bulkActions.disableScheduling') }}
+ {{ t('admin.accounts.bulkActions.edit') }}
+
+
+ {{ t('admin.accounts.bulkEdit.submit') }}
+
diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue
index c740d06f..6b245123 100644
--- a/frontend/src/components/auth/LinuxDoOAuthSection.vue
+++ b/frontend/src/components/auth/LinuxDoOAuthSection.vue
@@ -42,9 +42,11 @@