diff --git a/.gitignore b/.gitignore index bf7ee064..cf251f07 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ docs/claude-relay-service/ +.codex # =================== # Go 后端 @@ -121,7 +122,7 @@ scripts .code-review-state #openspec/ code-reviews/ -#AGENTS.md +AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ diff --git a/README.md b/README.md index 3e609d65..718730c6 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control. + +pateway +Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. It offers the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line. +Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information. +Register now via this link to receive $3 in trial credits. User top-ups are discounted by up to 40%, and referring friends earns both parties rewards — referral bonuses up to $150. 
+ + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index add32a17..24600e0e 100644 --- a/README_CN.md +++ b/README_CN.md @@ -100,6 +100,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。 + +pateway +感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。 +同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。 +现在通过 此链接 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。 + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index ccd595b9..1e89610c 100644 --- a/README_JA.md +++ b/README_JA.md @@ -100,6 +100,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。 + +pateway +PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。 +エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。 +こちらのリンクから登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。 + + ## エコシステム diff --git a/assets/partners/logos/pateway.png b/assets/partners/logos/pateway.png new file mode 100644 index 00000000..7ca3489a Binary files /dev/null and b/assets/partners/logos/pateway.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 1fcba8fa..025c3166 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.118 +0.1.121 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 6aaaa0c4..ff385516 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userGroupRateRepository := repository.NewUserGroupRateRepository(db) 
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) apiKeyCache := repository.NewAPIKeyCache(redisClient) - apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) + apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) @@ -143,6 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) + claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) @@ -155,7 +156,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { windsurfTokenProvider := 
service.ProvideWindsurfTokenProvider(configConfig, accountRepository, proxyRepository) windsurfChatService := service.ProvideWindsurfChatService(configConfig, windsurfLSService, windsurfTokenProvider, gatewayCache) windsurfGatewayService := service.ProvideWindsurfGatewayService(configConfig, windsurfChatService, accountRepository) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, windsurfChatService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) @@ -182,7 +183,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) @@ -191,7 +191,7 @@ func 
initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { rpmTokenBucketService := service.NewRPMTokenBucketService() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := 
service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 1b0a79ec..d52f5342 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -33,6 +33,7 @@ const ( AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) AccountTypeWindsurfSession = "windsurf-session" // Windsurf Session 类型账号(邮箱密码登录获取的 session token + api_key) + AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI) ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 4dcfaa6b..c107c329 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -99,7 +99,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -118,7 +118,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"` + Type string `json:"type" 
binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -135,19 +135,29 @@ type UpdateAccountRequest struct { // BulkUpdateAccountsRequest represents the payload for bulk editing accounts type BulkUpdateAccountsRequest struct { - AccountIDs []int64 `json:"account_ids" binding:"required,min=1"` - Name string `json:"name"` - ProxyID *int64 `json:"proxy_id"` - Concurrency *int `json:"concurrency"` - Priority *int `json:"priority"` - RateMultiplier *float64 `json:"rate_multiplier"` - LoadFactor *int `json:"load_factor"` - Status string `json:"status" binding:"omitempty,oneof=active inactive error"` - Schedulable *bool `json:"schedulable"` - GroupIDs *[]int64 `json:"group_ids"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra"` - ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 + AccountIDs []int64 `json:"account_ids"` + Filters *BulkUpdateAccountFilters `json:"filters"` + Name string `json:"name"` + ProxyID *int64 `json:"proxy_id"` + Concurrency *int `json:"concurrency"` + Priority *int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` + Status string `json:"status" binding:"omitempty,oneof=active inactive error"` + Schedulable *bool `json:"schedulable"` + GroupIDs *[]int64 `json:"group_ids"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 +} + +type BulkUpdateAccountFilters struct { + Platform string `json:"platform"` + Type string `json:"type"` + Status string `json:"status"` + Group string `json:"group"` + Search string `json:"search"` + PrivacyMode string `json:"privacy_mode"` } // CheckMixedChannelRequest represents check mixed channel risk request @@ -1370,6 +1380,10 @@ 
func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + if len(req.AccountIDs) == 0 && req.Filters == nil { + response.BadRequest(c, "account_ids or filters is required") + return + } // base_rpm 输入校验:负值归零,超过 10000 截断 sanitizeExtraBaseRPM(req.Extra) @@ -1395,6 +1409,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{ AccountIDs: req.AccountIDs, + Filters: toServiceBulkUpdateAccountFilters(req.Filters), Name: req.Name, ProxyID: req.ProxyID, Concurrency: req.Concurrency, @@ -1430,6 +1445,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.Success(c, result) } +func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters { + if filters == nil { + return nil + } + return &service.BulkUpdateAccountFilters{ + Platform: filters.Platform, + Type: filters.Type, + Status: filters.Status, + Group: filters.Group, + Search: filters.Search, + PrivacyMode: filters.PrivacyMode, + } +} + // ========== OAuth Handlers ========== // GenerateAuthURLRequest represents the request for generating auth URL diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index 24ec5bcf..929dc240 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) { require.Equal(t, float64(2), data["success"]) require.Equal(t, float64(0), data["failed"]) } + +func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "filters": map[string]any{ + "platform": 
"openai", + "type": "oauth", + "status": "active", + "group": "12", + "privacy_mode": "blocked", + "search": "bulk-target", + }, + "schedulable": true, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) +} diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go index 3833d32e..6df49154 100644 --- a/backend/internal/handler/admin/admin_helpers_test.go +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) { require.True(t, isAddrInTrustedProxies(addr, prefixes)) require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes)) } + +// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin +// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为 +// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。 +func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + require.Nil(t, openaiFastPolicySettingsFromDTO(nil)) + }) + + t.Run("empty service_tier becomes 'all'", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: "", + Action: "filter", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.NotNil(t, out) + require.Len(t, out.Rules, 1) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier) + 
require.Equal(t, "all", out.Rules[0].ServiceTier) + }) + + t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: " ", + Action: "pass", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier) + }) + + t.Run("uppercase service_tier is lowercased", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: "PRIORITY", + Action: "filter", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier) + }) + + t.Run("non-empty values pass through (lowercased)", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{ + {ServiceTier: "priority", Action: "filter", Scope: "all"}, + {ServiceTier: "flex", Action: "block", Scope: "oauth"}, + {ServiceTier: "all", Action: "pass", Scope: "apikey"}, + }, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Len(t, out.Rules, 3) + require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier) + require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier) + }) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 2fe29fa3..b187b47f 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -565,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return nil, service.ErrAPIKeyNotFound } +func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) { + for i := range s.apiKeys { + if s.apiKeys[i].ID == keyID 
{ + s.apiKeys[i].Usage5h = 0 + s.apiKeys[i].Usage1d = 0 + s.apiKeys[i].Usage7d = 0 + s.apiKeys[i].Window5hStart = nil + s.apiKeys[i].Window1dStart = nil + s.apiKeys[i].Window7dStart = nil + k := s.apiKeys[i] + return &k, nil + } + } + return nil, service.ErrAPIKeyNotFound +} + func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error { return nil } diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go index 8dd245a4..5e405bdd 100644 --- a/backend/internal/handler/admin/apikey_handler.go +++ b/backend/internal/handler/admin/apikey_handler.go @@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle } } -// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group +// AdminUpdateAPIKeyGroupRequest represents the request to update an API key. type AdminUpdateAPIKeyGroupRequest struct { - GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 + GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量 } -// UpdateGroup handles updating an API key's group binding +// UpdateGroup handles updating an API key's admin-managed fields. 
// PUT /api/v1/admin/api-keys/:id func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { keyID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { return } + var resetKey *service.APIKey + if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage { + resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID) if err != nil { response.ErrorFrom(c, err) return } + if resetKey != nil && req.GroupID == nil { + result.APIKey = resetKey + } resp := struct { APIKey *dto.APIKey `json:"api_key"` diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go index bf128b18..6ac6d52f 100644 --- a/backend/internal/handler/admin/apikey_handler_test.go +++ b/backend/internal/handler/admin/apikey_handler_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/service" @@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) { require.Nil(t, resp.Data.APIKey.GroupID) } +func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) { + svc := newStubAdminService() + now := time.Now() + svc.apiKeys[0].Usage5h = 1.2 + svc.apiKeys[0].Usage1d = 3.4 + svc.apiKeys[0].Usage7d = 5.6 + svc.apiKeys[0].Window5hStart = &now + svc.apiKeys[0].Window1dStart = &now + svc.apiKeys[0].Window7dStart = &now + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, 
rec.Code) + + var resp struct { + Data struct { + APIKey struct { + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5hStart *time.Time `json:"window_5h_start"` + Window1dStart *time.Time `json:"window_1d_start"` + Window7dStart *time.Time `json:"window_7d_start"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Zero(t, resp.Data.APIKey.Usage5h) + require.Zero(t, resp.Data.APIKey.Usage1d) + require.Zero(t, resp.Data.APIKey.Usage7d) + require.Nil(t, resp.Data.APIKey.Window5hStart) + require.Nil(t, resp.Data.APIKey.Window1dStart) + require.Nil(t, resp.Data.APIKey.Window7dStart) +} + func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) { svc := &failingUpdateGroupService{ stubAdminService: newStubAdminService(), diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 40bf1c69..59f4fe85 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -186,6 +186,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, AffiliateRebateRate: settings.AffiliateRebateRate, + AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours, + AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap, DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, @@ -206,6 +209,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, + EnableAnthropicCacheTTL1hInjection: 
settings.EnableAnthropicCacheTTL1hInjection, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, @@ -245,9 +249,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { AffiliateEnabled: settings.AffiliateEnabled, } + + // OpenAI fast policy (stored under a dedicated setting key) + if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { + slog.Error("openai_fast_policy_settings_get_failed", "error", err) + } else if fastPolicy != nil { + payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } +// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy. +func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings { + if s == nil { + return nil + } + rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules)) + for i, r := range s.Rules { + rules[i] = dto.OpenAIFastPolicyRule(r) + } + return &dto.OpenAIFastPolicySettings{Rules: rules} +} + +// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy. 
+// +// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为 +// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时 +// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由 +// service.SetOpenAIFastPolicySettings 负责合法值校验。 +func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings { + if s == nil { + return nil + } + rules := make([]service.OpenAIFastPolicyRule, len(s.Rules)) + for i, r := range s.Rules { + rules[i] = service.OpenAIFastPolicyRule(r) + tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier)) + if tier == "" { + tier = service.OpenAIFastTierAny + } + rules[i].ServiceTier = tier + } + return &service.OpenAIFastPolicySettings{Rules: rules} +} + // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 @@ -342,6 +388,9 @@ type UpdateSettingsRequest struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` + AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"` + AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"` + AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"` DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` @@ -393,9 +442,10 @@ type UpdateSettingsRequest struct { BackendModeEnabled bool `json:"backend_mode_enabled"` // Gateway forwarding behavior - EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` - EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` - EnableCCHSigning *bool `json:"enable_cch_signing"` + EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` + EnableCCHSigning *bool `json:"enable_cch_signing"` + 
EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"` // Payment visible method routing PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` @@ -446,6 +496,9 @@ type UpdateSettingsRequest struct { // Affiliate (邀请返利) feature switch AffiliateEnabled *bool `json:"affiliate_enabled"` + + // OpenAI fast/flex policy (optional, only updated when provided) + OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } // UpdateSettings 更新系统设置 @@ -485,6 +538,33 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if affiliateRebateRate > service.AffiliateRebateRateMax { affiliateRebateRate = service.AffiliateRebateRateMax } + affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours + if req.AffiliateRebateFreezeHours != nil { + affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours + } + if affiliateRebateFreezeHours < 0 { + affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault + } + if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax { + affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax + } + affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays + if req.AffiliateRebateDurationDays != nil { + affiliateRebateDurationDays = *req.AffiliateRebateDurationDays + } + if affiliateRebateDurationDays < 0 { + affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault + } + if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax { + affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax + } + affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap + if req.AffiliateRebatePerInviteeCap != nil { + affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap + } + if affiliateRebatePerInviteeCap < 0 { + affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault + } // 
通用表格配置:兼容旧客户端未传字段时保留当前值。 if req.TableDefaultPageSize <= 0 { req.TableDefaultPageSize = previousSettings.TableDefaultPageSize @@ -1137,6 +1217,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, AffiliateRebateRate: affiliateRebateRate, + AffiliateRebateFreezeHours: affiliateRebateFreezeHours, + AffiliateRebateDurationDays: affiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap, DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, @@ -1192,6 +1275,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + EnableAnthropicCacheTTL1hInjection: func() bool { + if req.EnableAnthropicCacheTTL1hInjection != nil { + return *req.EnableAnthropicCacheTTL1hInjection + } + return previousSettings.EnableAnthropicCacheTTL1hInjection + }(), PaymentVisibleMethodAlipaySource: func() string { if req.PaymentVisibleMethodAlipaySource != nil { return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) @@ -1314,6 +1403,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } + // Update OpenAI fast policy (stored under dedicated key, only when provided). + if req.OpenAIFastPolicySettings != nil { + if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil { + response.BadRequest(c, err.Error()) + return + } + } + // Update payment configuration (integrated into system settings). // Skip if no payment fields were provided (prevents accidental wipe). 
if h.paymentConfigService != nil && hasPaymentFields(req) { @@ -1458,6 +1555,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, AffiliateRebateRate: updatedSettings.AffiliateRebateRate, + AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours, + AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap, DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -1478,6 +1578,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection, PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, @@ -1516,6 +1617,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { AffiliateEnabled: updatedSettings.AffiliateEnabled, } + if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { + slog.Error("openai_fast_policy_settings_get_failed", "error", err) + } else if fastPolicy != nil { + payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) + } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } @@ -1768,6 +1874,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.AffiliateRebateRate != after.AffiliateRebateRate { changed 
= append(changed, "affiliate_rebate_rate") } + if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours { + changed = append(changed, "affiliate_rebate_freeze_hours") + } + if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays { + changed = append(changed, "affiliate_rebate_duration_days") + } + if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap { + changed = append(changed, "affiliate_rebate_per_invitee_cap") + } if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { changed = append(changed, "default_subscriptions") } @@ -1843,6 +1958,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection { + changed = append(changed, "enable_anthropic_cache_ttl_1h_injection") + } if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { changed = append(changed, "payment_visible_method_alipay_source") } diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index 9a33a93a..085fd2ca 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service. 
} func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { - panic("unexpected GetValue call") + if s.values != nil { + if value, ok := s.values[key]; ok { + return value, nil + } + } + return "", nil } func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2ef05963..7df4abfd 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -435,6 +435,7 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( type completeLinuxDoOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -518,7 +519,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 604ad903..490afd0f 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -67,6 +67,7 @@ type createPendingOAuthAccountRequest struct { VerifyCode string `json:"verify_code,omitempty"` Password string `json:"password" binding:"required,min=6"` InvitationCode string `json:"invitation_code,omitempty"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool 
`json:"adopt_avatar,omitempty"` } @@ -1751,6 +1752,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) user, strings.TrimSpace(req.InvitationCode), strings.TrimSpace(session.ProviderType), + strings.TrimSpace(req.AffCode), ); err != nil { _ = tx.Rollback() if rollbackCreatedUser(err) { diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0ac8871b..4264002d 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -582,6 +582,7 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( type completeOIDCOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -665,7 +666,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index efee4cc0..34e70ed0 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -481,6 +481,7 @@ func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService type completeWeChatOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } @@ -547,7 +548,7 @@ func (h 
*AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 051fab18..492be170 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -106,11 +106,14 @@ type SystemSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` - DefaultUserRPMLimit int `json:"default_user_rpm_limit"` - DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate float64 `json:"affiliate_rebate_rate"` + AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"` + AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"` + AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` + DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -139,9 +142,10 @@ type SystemSettings struct { BackendModeEnabled bool `json:"backend_mode_enabled"` // Gateway forwarding behavior - EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` - EnableMetadataPassthrough bool 
`json:"enable_metadata_passthrough"` - EnableCCHSigning bool `json:"enable_cch_signing"` + EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` + EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` + EnableCCHSigning bool `json:"enable_cch_signing"` + EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"` // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` @@ -195,6 +199,9 @@ type SystemSettings struct { // Affiliate (邀请返利) feature switch AffiliateEnabled bool `json:"affiliate_enabled"` + + // OpenAI fast/flex policy + OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } type DefaultSubscriptionSetting struct { @@ -291,6 +298,22 @@ type BetaPolicySettings struct { Rules []BetaPolicyRule `json:"rules"` } +// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO +type OpenAIFastPolicyRule struct { + ServiceTier string `json:"service_tier"` + Action string `json:"action"` + Scope string `json:"scope"` + ErrorMessage string `json:"error_message,omitempty"` + ModelWhitelist []string `json:"model_whitelist,omitempty"` + FallbackAction string `json:"fallback_action,omitempty"` + FallbackErrorMessage string `json:"fallback_error_message,omitempty"` +} + +// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO +type OpenAIFastPolicySettings struct { + Rules []OpenAIFastPolicyRule `json:"rules"` +} + // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // Returns empty slice on empty/invalid input. 
func ParseCustomMenuItems(raw string) []CustomMenuItem { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7bb9c46d..d12c2941 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -288,6 +288,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + // [DEBUG-STICKY] 打印会话 hash 生成结果 + reqLog.Info("sticky.session_hash_generated", + zap.String("session_hash", sessionHash), + zap.String("metadata_user_id_raw", parsedReq.MetadataUserID), + ) + // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 platform := "" if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { @@ -304,6 +310,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + // [DEBUG-STICKY] 打印粘性会话查询结果 + reqLog.Info("sticky.cache_lookup", + zap.String("session_key", sessionKey), + zap.Int64("bound_account_id", sessionBoundAccountID), + ) if sessionBoundAccountID > 0 { prefetchedGroupID := int64(0) if apiKey.GroupID != nil { @@ -312,6 +323,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } + } else { + reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash)) } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 @@ -591,6 +604,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 + reqLog.Info("sticky.selecting_account", + zap.String("session_key", sessionKey), + zap.Int64("sticky_bound_account_id", sessionBoundAccountID), + zap.Bool("has_bound_session", 
hasBoundSession), + zap.Int("failed_account_count", len(fs.FailedAccountIDs)), + ) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { @@ -624,6 +643,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) + // [DEBUG-STICKY] 打印账号选择结果 + reqLog.Info("sticky.account_selected", + zap.Int64("selected_account_id", account.ID), + zap.String("account_name", account.Name), + zap.Bool("slot_acquired", selection.Acquired), + zap.Bool("has_wait_plan", selection.WaitPlan != nil), + zap.Int64("sticky_bound_account_id", sessionBoundAccountID), + zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID), + ) + // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) @@ -690,6 +719,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // Slot acquired: no longer waiting in queue. 
releaseWait() + reqLog.Info("sticky.bind_after_wait", + zap.String("session_key", sessionKey), + zap.Int64("account_id", account.ID), + ) if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } @@ -924,6 +957,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } + // 绑定粘性会话(成功转发后绑定/刷新) + // - 无现有绑定(首次请求):创建绑定 + // - 选中账号与粘性账号一致:刷新 TTL + // - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定, + // 下次请求粘性账号恢复后仍可命中 + if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) { + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 71030140..57554cf9 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time. 
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { return true, nil } +func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error { + return nil +} func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 403b41ef..4d0078a7 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { return } - sessionHash := "" - if parsed.Multipart { - sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed()) - } else { - sessionHash = h.gatewayService.GenerateSessionHash(c, body) - } + sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index 37bd38b2..e7d8aab9 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -25,6 +25,7 @@ const ( easypayStatusPaid = 1 easypayHTTPTimeout = 10 * time.Second maxEasypayResponseSize = 1 << 20 // 1MB + maxEasypayErrorSummary = 512 tradeStatusSuccess = "TRADE_SUCCESS" signTypeMD5 = "MD5" paymentModePopup = "popup" @@ -42,17 +43,55 @@ type EasyPay struct { // config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) { for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} { - if config[k] == "" { + if strings.TrimSpace(config[k]) == "" { return nil, fmt.Errorf("easypay config missing required key: %s", k) } } + cfg := make(map[string]string, len(config)) + for k, v := 
range config { + cfg[k] = v + } + cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"]) return &EasyPay{ instanceID: instanceID, - config: config, + config: cfg, httpClient: &http.Client{Timeout: easypayHTTPTimeout}, }, nil } +func normalizeEasyPayAPIBase(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return "" + } + if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" { + parsed.RawQuery = "" + parsed.Fragment = "" + parsed.RawPath = "" + parsed.Path = trimEasyPayEndpointPath(parsed.Path) + return strings.TrimRight(parsed.String(), "/") + } + return strings.TrimRight(trimEasyPayEndpointPath(base), "/") +} + +func trimEasyPayEndpointPath(path string) string { + path = strings.TrimRight(strings.TrimSpace(path), "/") + lower := strings.ToLower(path) + for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} { + if strings.HasSuffix(lower, endpoint) { + return strings.TrimRight(path[:len(path)-len(endpoint)], "/") + } + } + return path +} + +func (e *EasyPay) apiBase() string { + if e == nil { + return "" + } + return normalizeEasyPayAPIBase(e.config["apiBase"]) +} + func (e *EasyPay) Name() string { return "EasyPay" } func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay } func (e *EasyPay) SupportedTypes() []payment.PaymentType { @@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym for k, v := range params { q.Set(k, v) } - base := strings.TrimRight(e.config["apiBase"], "/") - payURL := base + "/submit.php?" + q.Encode() + payURL := e.apiBase() + "/submit.php?" 
+ q.Encode() return &payment.CreatePaymentResponse{PayURL: payURL}, nil } @@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen params["sign"] = easyPaySign(params, e.config["pkey"]) params["sign_type"] = signTypeMD5 - body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params) + body, err := e.post(ctx, e.apiBase()+"/mapi.php", params) if err != nil { return nil, fmt.Errorf("easypay create: %w", err) } @@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer "act": "order", "pid": e.config["pid"], "key": e.config["pkey"], "out_trade_no": tradeNo, } - body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params) + body, err := e.post(ctx, e.apiBase()+"/api.php", params) if err != nil { return nil, fmt.Errorf("easypay query: %w", err) } @@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st } func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { - params := map[string]string{ - "pid": e.config["pid"], "key": e.config["pkey"], - "trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount, + attempts := e.refundAttempts(req) + if len(attempts) == 0 { + return nil, fmt.Errorf("easypay refund missing order identifier") } - body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params) - if err != nil { - return nil, fmt.Errorf("easypay refund: %w", err) + var firstErr error + for i, attempt := range attempts { + body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params) + if err != nil { + return nil, fmt.Errorf("easypay refund request: %w", err) + } + if err := parseEasyPayRefundResponse(status, body); err != nil { + if firstErr == nil { + firstErr = err + } + if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) { + continue + } + return nil, err + } + return 
&payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil } + return nil, firstErr +} + +type easyPayRefundAttempt struct { + params map[string]string + refundID string +} + +func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt { + base := map[string]string{ + "pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount, + } + var attempts []easyPayRefundAttempt + if orderID := strings.TrimSpace(req.OrderID); orderID != "" { + params := cloneStringMap(base) + params["out_trade_no"] = orderID + attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID}) + } + if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" { + params := cloneStringMap(base) + params["trade_no"] = tradeNo + attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo}) + } + return attempts +} + +func cloneStringMap(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func isEasyPayRefundOrderNotFound(err error) bool { + if err == nil { + return false + } + msg := err.Error() + lower := strings.ToLower(msg) + return strings.Contains(msg, "订单编号不存在") || + strings.Contains(msg, "订单不存在") || + strings.Contains(lower, "order not found") || + strings.Contains(lower, "not exist") +} + +func parseEasyPayRefundResponse(status int, body []byte) error { + summary := summarizeEasyPayResponse(body) + if status < http.StatusOK || status >= http.StatusMultipleChoices { + return fmt.Errorf("easypay refund HTTP %d: %s", status, summary) + } + + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary) + } + + lower := strings.ToLower(trimmed) + if strings.HasPrefix(lower, "" + } + if len(summary) > maxEasypayErrorSummary { + return summary[:maxEasypayErrorSummary] + "..." 
+ } + return summary } func (e *EasyPay) resolveCID(paymentType string) string { @@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string { } func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) { + body, _, err := e.postRaw(ctx, endpoint, params) + return body, err +} + +func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) { form := url.Values{} for k, v := range params { form.Set(k, v) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) if err != nil { - return nil, err + return nil, 0, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := e.httpClient.Do(req) + client := e.httpClient + if client == nil { + client = &http.Client{Timeout: easypayHTTPTimeout} + } + resp, err := client.Do(req) if err != nil { - return nil, err + return nil, 0, err } defer func() { _ = resp.Body.Close() }() - return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize)) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize)) + if err != nil { + return nil, resp.StatusCode, err + } + return body, resp.StatusCode, nil } func easyPaySign(params map[string]string, pkey string) string { diff --git a/backend/internal/payment/provider/easypay_refund_test.go b/backend/internal/payment/provider/easypay_refund_test.go new file mode 100644 index 00000000..9e0e4942 --- /dev/null +++ b/backend/internal/payment/provider/easypay_refund_test.go @@ -0,0 +1,196 @@ +package provider + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeEasyPayAPIBase(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want string + }{ + {input: "https://zpayz.cn", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/", want: 
"https://zpayz.cn"}, + {input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + if got := normalizeEasyPayAPIBase(tt.input); got != tt.want { + t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) { + t.Parallel() + + var gotPath string + var gotQuery url.Values + var gotForm url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQuery = r.URL.Query() + if err := r.ParseForm(); err != nil { + t.Errorf("ParseForm: %v", err) + } + gotForm = r.PostForm + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL+"/mapi.php") + resp, err := provider.Refund(context.Background(), payment.RefundRequest{ + TradeNo: "trade-123", + OrderID: "out-456", + Amount: "1.50", + }) + if err != nil { + t.Fatalf("Refund returned error: %v", err) + } + if resp == nil || resp.Status != payment.ProviderStatusSuccess { + t.Fatalf("Refund response = %+v, want success", resp) + } + if gotPath != "/api.php" { + t.Fatalf("refund path = %q, want /api.php", gotPath) + } + if gotQuery.Get("act") != "refund" { + t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act")) + } + for key, want := range map[string]string{ + "pid": "pid-1", + "key": "pkey-1", + "out_trade_no": "out-456", + "money": "1.50", + } { + if got := gotForm.Get(key); got != want { + t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm) + } + } + if got := gotForm.Get("trade_no"); got != "" { + 
t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm) + } +} + +func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) { + t.Parallel() + + var gotForms []url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api.php" { + t.Errorf("refund path = %q, want /api.php", r.URL.Path) + } + if r.URL.Query().Get("act") != "refund" { + t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act")) + } + if err := r.ParseForm(); err != nil { + t.Errorf("ParseForm: %v", err) + } + gotForms = append(gotForms, r.PostForm) + w.Header().Set("Content-Type", "application/json") + if len(gotForms) == 1 { + _, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`)) + return + } + _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL+"/mapi.php") + resp, err := provider.Refund(context.Background(), payment.RefundRequest{ + TradeNo: "trade-123", + OrderID: "out-456", + Amount: "1.50", + }) + if err != nil { + t.Fatalf("Refund returned error: %v", err) + } + if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" { + t.Fatalf("Refund response = %+v, want success with trade refund id", resp) + } + if len(gotForms) != 2 { + t.Fatalf("refund attempts = %d, want 2", len(gotForms)) + } + if got := gotForms[0].Get("out_trade_no"); got != "out-456" { + t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0]) + } + if got := gotForms[0].Get("trade_no"); got != "" { + t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0]) + } + if got := gotForms[1].Get("trade_no"); got != "trade-123" { + t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1]) + } + if got := gotForms[1].Get("out_trade_no"); got != "" { + t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1]) + } +} + +func 
TestEasyPayRefundResponseErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + want string + }{ + {name: "html response", statusCode: http.StatusOK, body: "bad config", want: "non-JSON response (HTTP 200): bad config"}, + {name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"}, + {name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"}, + {name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL) + _, err := provider.Refund(context.Background(), payment.RefundRequest{ + OrderID: "out-456", + Amount: "1.50", + }) + if err == nil { + t.Fatal("Refund returned nil error") + } + if !strings.Contains(err.Error(), tt.want) { + t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want) + } + }) + } +} + +func newTestEasyPay(t *testing.T, apiBase string) *EasyPay { + t.Helper() + + provider, err := NewEasyPay("test-instance", map[string]string{ + "pid": "pid-1", + "pkey": "pkey-1", + "apiBase": apiBase, + "notifyUrl": "https://example.com/notify", + "returnUrl": "https://example.com/return", + }) + if err != nil { + t.Fatalf("NewEasyPay: %v", err) + } + return provider +} diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 095305c2..e8b25c2b 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t 
*testing.T) { assert.Equal(t, 5, anth.Usage.OutputTokens) } +func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cached", + Model: "gpt-5.2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Cached response"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 54006, + OutputTokens: 123, + TotalTokens: 54129, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 50688, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929") + assert.Equal(t, 3318, anth.Usage.InputTokens) + assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens) + assert.Equal(t, 123, anth.Usage.OutputTokens) +} + +func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cached_clamp", + Model: "gpt-5.2", + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 5, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 150, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929") + assert.Equal(t, 0, anth.Usage.InputTokens) + assert.Equal(t, 150, anth.Usage.CacheReadInputTokens) + assert.Equal(t, 5, anth.Usage.OutputTokens) +} + func TestResponsesToAnthropic_ToolUse(t *testing.T) { resp := &ResponsesResponse{ ID: "resp_456", @@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) { assert.Equal(t, "tool_use", anth.Content[1].Type) assert.Equal(t, "call_1", anth.Content[1].ID) assert.Equal(t, "get_weather", anth.Content[1].Name) + assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input)) +} + +func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_read", + Model: "gpt-5.5", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_read", + Name: 
"Read", + Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 1) + assert.Equal(t, "tool_use", anth.Content[0].Type) + assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input)) +} + +func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_other", + Model: "gpt-5.5", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_other", + Name: "Search", + Arguments: `{"query":""}`, + }, + }, + } + + anth := ResponsesToAnthropic(resp, "claude-opus-4-6") + require.Len(t, anth.Content, 1) + assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input)) } func TestResponsesToAnthropic_Reasoning(t *testing.T) { @@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) { assert.Equal(t, "message_stop", events[1].Type) } +func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) { + state := NewResponsesEventToAnthropicState() + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 54006, + OutputTokens: 123, + TotalTokens: 54129, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 50688, + }, + }, + }, + }, state) + + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + assert.Equal(t, 3318, events[0].Usage.InputTokens) + assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens) + assert.Equal(t, 123, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) +} + func TestStreamingToolCall(t *testing.T) { state 
:= NewResponsesEventToAnthropicState() @@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) { assert.Equal(t, "tool_use", events[0].Delta.StopReason) } +func TestStreamingReadToolDropsEmptyPages(t *testing.T) { + state := NewResponsesEventToAnthropicState() + + ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"}, + }, state) + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"}, + }, state) + require.Len(t, events, 1) + assert.Equal(t, "content_block_start", events[0].Type) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 0, + Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`, + }, state) + assert.Len(t, events, 0) + + events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.done", + OutputIndex: 0, + Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`, + }, state) + require.Len(t, events, 2) + assert.Equal(t, "content_block_delta", events[0].Type) + assert.Equal(t, "input_json_delta", events[0].Delta.Type) + assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON) + assert.Equal(t, "content_block_stop", events[1].Type) +} + func TestStreamingReasoning(t *testing.T) { state := NewResponsesEventToAnthropicState() @@ -835,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) { var tc map[string]any require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) assert.Equal(t, "function", tc["type"]) - fn, ok := tc["function"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "get_weather", fn["name"]) + assert.Equal(t, "get_weather", tc["name"]) + 
assert.NotContains(t, tc, "function") +} + +func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-5.2", + Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`), + ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`), + } + + resp, err := ResponsesToAnthropicRequest(req) + require.NoError(t, err) + + var tc map[string]string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "tool", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) +} + +func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-5.2", + Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`), + ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`), + } + + resp, err := ResponsesToAnthropicRequest(req) + require.NoError(t, err) + + var tc map[string]string + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "tool", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) } // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go index 485262e8..268f9f22 100644 --- a/backend/internal/pkg/apicompat/anthropic_to_responses.go +++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go @@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) { // {"type":"auto"} → "auto" // {"type":"any"} → "required" // {"type":"none"} → "none" -// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}} +// {"type":"tool","name":"X"} → {"type":"function","name":"X"} func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) { var tc struct { Type string `json:"type"` @@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw 
json.RawMessage) (json.RawMessage return json.Marshal("none") case "tool": return json.Marshal(map[string]any{ - "type": "function", - "function": map[string]string{"name": tc.Name}, + "type": "function", + "name": tc.Name, }) default: // Pass through unknown types as-is diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index c140449a..35d42999 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { var tc map[string]any require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) assert.Equal(t, "function", tc["type"]) + assert.Equal(t, "get_weather", tc["name"]) + assert.NotContains(t, tc, "function") } func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index c2725406..64ef5781 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R // // "auto" → "auto" // "none" → "none" -// {"name":"X"} → {"type":"function","function":{"name":"X"}} +// {"name":"X"} → {"type":"function","name":"X"} func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { // Try string first ("auto", "none", etc.) — pass through as-is. 
var s string @@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, return nil, err } return json.Marshal(map[string]any{ - "type": "function", - "function": map[string]string{"name": obj.Name}, + "type": "function", + "name": obj.Name, }) } diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index 5409a0f4..489ed238 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo Type: "tool_use", ID: fromResponsesCallID(item.CallID), Name: item.Name, - Input: json.RawMessage(item.Arguments), + Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments), }) case "web_search_call": toolUseID := "srvtoolu_" + item.ID @@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks) if resp.Usage != nil { - out.Usage = AnthropicUsage{ - InputTokens: resp.Usage.InputTokens, - OutputTokens: resp.Usage.OutputTokens, - } - if resp.Usage.InputTokensDetails != nil { - out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens - } + out.Usage = anthropicUsageFromResponsesUsage(resp.Usage) } return out } +func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage { + if usage == nil { + return AnthropicUsage{} + } + + cachedTokens := 0 + if usage.InputTokensDetails != nil { + cachedTokens = usage.InputTokensDetails.CachedTokens + } + + inputTokens := usage.InputTokens - cachedTokens + if inputTokens < 0 { + inputTokens = 0 + } + + return AnthropicUsage{ + InputTokens: inputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: cachedTokens, + } +} + func responsesStatusToAnthropicStopReason(status string, 
details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string { switch status { case "incomplete": @@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom } } +func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage { + if name != "Read" || raw == "" { + return json.RawMessage(raw) + } + + var input map[string]json.RawMessage + if err := json.Unmarshal([]byte(raw), &input); err != nil { + return json.RawMessage(raw) + } + + if pages, ok := input["pages"]; !ok || string(pages) != `""` { + return json.RawMessage(raw) + } + + delete(input, "pages") + sanitized, err := json.Marshal(input) + if err != nil { + return json.RawMessage(raw) + } + return sanitized +} + // --------------------------------------------------------------------------- // Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter) // --------------------------------------------------------------------------- @@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct { ContentBlockIndex int ContentBlockOpen bool CurrentBlockType string // "text" | "thinking" | "tool_use" + CurrentToolName string + CurrentToolArgs string // OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index. 
OutputIndexToBlockIdx map[int]int @@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents( case "response.function_call_arguments.delta": return resToAnthHandleFuncArgsDelta(evt, state) case "response.function_call_arguments.done": - return resToAnthHandleBlockDone(state) + return resToAnthHandleFuncArgsDone(evt, state) case "response.output_item.done": return resToAnthHandleOutputItemDone(evt, state) case "response.reasoning_summary_text.delta": @@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE state.OutputIndexToBlockIdx[evt.OutputIndex] = idx state.ContentBlockOpen = true state.CurrentBlockType = "tool_use" + state.CurrentToolName = evt.Item.Name + state.CurrentToolArgs = "" events = append(events, AnthropicStreamEvent{ Type: "content_block_start", @@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve return nil } + if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" { + state.CurrentToolArgs += evt.Delta + return nil + } + blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex] if !ok { return nil @@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve }} } +func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { + if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" { + return resToAnthHandleBlockDone(state) + } + + raw := evt.Arguments + if raw == "" { + raw = state.CurrentToolArgs + } + sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw) + if len(sanitized) == 0 { + return closeCurrentBlock(state) + } + + idx := state.ContentBlockIndex + events := []AnthropicStreamEvent{{ + Type: "content_block_delta", + Index: &idx, + Delta: &AnthropicDelta{ + Type: "input_json_delta", + PartialJSON: string(sanitized), + }, + }} + events = append(events, closeCurrentBlock(state)...) 
+ return events +} + func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent { if evt.Delta == "" { return nil @@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo stopReason := "end_turn" if evt.Response != nil { if evt.Response.Usage != nil { - state.InputTokens = evt.Response.Usage.InputTokens - state.OutputTokens = evt.Response.Usage.OutputTokens - if evt.Response.Usage.InputTokensDetails != nil { - state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens - } + usage := anthropicUsageFromResponsesUsage(evt.Response.Usage) + state.InputTokens = usage.InputTokens + state.OutputTokens = usage.OutputTokens + state.CacheReadInputTokens = usage.CacheReadInputTokens } switch evt.Response.Status { case "incomplete": @@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE idx := state.ContentBlockIndex state.ContentBlockOpen = false state.ContentBlockIndex++ + state.CurrentToolName = "" + state.CurrentToolArgs = "" return []AnthropicStreamEvent{{ Type: "content_block_stop", Index: &idx, diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go index 49426b88..8fa652f2 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go @@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { // "auto" → {"type":"auto"} // "required" → {"type":"any"} // "none" → {"type":"none"} -// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} +// {"type":"function","name":"X"} → {"type":"tool","name":"X"} +// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, 
error) { // Try as string first var s string @@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage // Try as object with type=function var tc struct { Type string `json:"type"` + Name string `json:"name"` Function struct { Name string `json:"name"` } `json:"function"` } - if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { + if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" { + name := strings.TrimSpace(tc.Name) + if name == "" { + name = strings.TrimSpace(tc.Function.Name) + } + if name == "" { + return raw, nil + } return json.Marshal(map[string]string{ "type": "tool", - "name": tc.Function.Name, + "name": name, }) } diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go index 69e99dc5..cee12948 100644 --- a/backend/internal/pkg/httputil/body.go +++ b/backend/internal/pkg/httputil/body.go @@ -2,16 +2,28 @@ package httputil import ( "bytes" + "compress/gzip" + "compress/zlib" + "errors" + "fmt" "io" "net/http" + "strings" + + "github.com/klauspost/compress/zstd" ) const ( requestBodyReadInitCap = 512 requestBodyReadMaxInitCap = 1 << 20 + // maxDecompressedBodySize limits the decompressed request body to 64 MB + // to prevent decompression bomb attacks. + maxDecompressedBodySize = 64 << 20 ) -// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based +// on content length, transparently decoding any Content-Encoding the upstream +// client used to compress the body (zstd, gzip, deflate). 
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { if req == nil || req.Body == nil { return nil, nil @@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { if _, err := io.Copy(buf, req.Body); err != nil { return nil, err } - return buf.Bytes(), nil + raw := buf.Bytes() + + enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding"))) + if enc == "" || enc == "identity" { + return raw, nil + } + + decoded, err := decompressRequestBody(enc, raw) + if err != nil { + return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err) + } + + req.Header.Del("Content-Encoding") + req.Header.Del("Content-Length") + req.ContentLength = int64(len(decoded)) + + return decoded, nil +} + +func decompressRequestBody(encoding string, raw []byte) ([]byte, error) { + switch encoding { + case "zstd": + dec, err := zstd.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer dec.Close() + return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize)) + case "gzip", "x-gzip": + gr, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer func() { _ = gr.Close() }() + return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize)) + case "deflate": + zr, err := zlib.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer func() { _ = zr.Close() }() + return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize)) + default: + return nil, errors.New("unsupported Content-Encoding") + } } diff --git a/backend/internal/pkg/httputil/body_test.go b/backend/internal/pkg/httputil/body_test.go new file mode 100644 index 00000000..ed8355d5 --- /dev/null +++ b/backend/internal/pkg/httputil/body_test.go @@ -0,0 +1,143 @@ +package httputil + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "net/http" + "strings" + "testing" + + "github.com/klauspost/compress/zstd" +) + +const samplePayload = 
`{"model":"gpt-5.5","input":"hi","stream":false}` + +func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request { + t.Helper() + req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if encoding != "" { + req.Header.Set("Content-Encoding", encoding) + } + req.ContentLength = int64(len(body)) + return req +} + +func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) { + enc, _ := zstd.NewWriter(nil) + compressed := enc.EncodeAll([]byte(samplePayload), nil) + _ = enc.Close() + + req := newRequestWithBody(t, compressed, "zstd") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } + if req.Header.Get("Content-Encoding") != "" { + t.Fatalf("Content-Encoding should be cleared after decoding") + } + if req.ContentLength != int64(len(samplePayload)) { + t.Fatalf("ContentLength not updated: %d", req.ContentLength) + } +} + +func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write([]byte(samplePayload)); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + + req := newRequestWithBody(t, buf.Bytes(), "gzip") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func 
TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + if _, err := zw.Write([]byte(samplePayload)); err != nil { + t.Fatalf("zlib write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("zlib close: %v", err) + } + + req := newRequestWithBody(t, buf.Bytes(), "deflate") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "br") + _, err := ReadRequestBodyWithPrealloc(req) + if err == nil { + t.Fatal("expected error for unsupported encoding, got nil") + } + if !strings.Contains(err.Error(), "br") { + t.Fatalf("error should mention encoding, got %v", err) + } +} + +func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) { + req := newRequestWithBody(t, []byte("not actually zstd"), "zstd") + _, err := ReadRequestBodyWithPrealloc(req) + if err == nil { + t.Fatal("expected error for corrupt zstd body, got nil") + } +} + +func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Fatalf("expected nil body, got %q", got) + } +} + +func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) { + req := newRequestWithBody(t, []byte(samplePayload), "identity") + got, err := ReadRequestBodyWithPrealloc(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(got) != samplePayload { + t.Fatalf("body mismatch: got %q", got) + } +} diff --git a/backend/internal/repository/account_repo_integration_test.go 
b/backend/internal/repository/account_repo_integration_test.go index b249bb61..d1cea9eb 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi return true, nil } +func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error { + return nil +} + func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go index e3dd56b8..ef89e5b6 100644 --- a/backend/internal/repository/affiliate_repo.go +++ b/backend/internal/repository/affiliate_repo.go @@ -86,17 +86,21 @@ func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID return bound, nil } -func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) { +func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) { if amount <= 0 { return false, nil } var applied bool err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { - res, err := txClient.ExecContext(txCtx, - "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2", - amount, inviterID, - ) + // freezeHours > 0: add to frozen quota; == 0: add to available quota directly + var updateSQL string + if freezeHours > 0 { + updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2" + } else { + updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, 
updated_at = NOW() WHERE user_id = $2" + } + res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID) if err != nil { return err } @@ -106,10 +110,19 @@ func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, invite return nil } - if _, err = txClient.ExecContext(txCtx, ` + if freezeHours > 0 { + if _, err = txClient.ExecContext(txCtx, ` +INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at) +VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`, + inviterID, amount, inviteeUserID, freezeHours); err != nil { + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } + } else { + if _, err = txClient.ExecContext(txCtx, ` INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at) VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil { - return fmt.Errorf("insert affiliate accrue ledger: %w", err) + return fmt.Errorf("insert affiliate accrue ledger: %w", err) + } } applied = true @@ -121,6 +134,76 @@ VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); return applied, nil } +func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) { + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, + `SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`, + inviterID, inviteeUserID) + if err != nil { + return 0, fmt.Errorf("query accrued rebate from invitee: %w", err) + } + defer func() { _ = rows.Close() }() + var total float64 + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, rows.Close() +} + +func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) { + var thawed 
float64 + err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + var err error + thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID) + return err + }) + return thawed, err +} + +// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx. +func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) { + rows, err := txClient.QueryContext(txCtx, ` +WITH matured AS ( + UPDATE user_affiliate_ledger + SET frozen_until = NULL, updated_at = NOW() + WHERE user_id = $1 + AND frozen_until IS NOT NULL + AND frozen_until <= NOW() + RETURNING amount +) +SELECT COALESCE(SUM(amount), 0) FROM matured`, userID) + if err != nil { + return 0, fmt.Errorf("thaw frozen quota: %w", err) + } + defer func() { _ = rows.Close() }() + + var thawed float64 + if rows.Next() { + if err := rows.Scan(&thawed); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if thawed <= 0 { + return 0, nil + } + + _, err = txClient.ExecContext(txCtx, ` +UPDATE user_affiliates +SET aff_quota = aff_quota + $1, + aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0), + updated_at = NOW() +WHERE user_id = $2`, thawed, userID) + if err != nil { + return 0, fmt.Errorf("move thawed quota: %w", err) + } + return thawed, nil +} + func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) { var transferred float64 var newBalance float64 @@ -130,6 +213,11 @@ func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID return err } + // Thaw any matured frozen quota before transfer. 
+ if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil { + return fmt.Errorf("thaw before transfer: %w", err) + } + rows, err := txClient.QueryContext(txCtx, ` WITH claimed AS ( SELECT aff_quota::double precision AS amount @@ -211,10 +299,16 @@ func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, SELECT ua.user_id, COALESCE(u.email, ''), COALESCE(u.username, ''), - ua.created_at + ua.created_at, + COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate FROM user_affiliates ua LEFT JOIN users u ON u.id = ua.user_id +LEFT JOIN user_affiliate_ledger ual + ON ual.user_id = $1 + AND ual.source_user_id = ua.user_id + AND ual.action = 'accrue' WHERE ua.inviter_id = $1 +GROUP BY ua.user_id, u.email, u.username, ua.created_at ORDER BY ua.created_at DESC LIMIT $2`, inviterID, limit) if err != nil { @@ -226,7 +320,7 @@ LIMIT $2`, inviterID, limit) for rows.Next() { var item service.AffiliateInvitee var createdAt time.Time - if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil { + if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil { return nil, err } item.CreatedAt = &createdAt @@ -299,6 +393,7 @@ SELECT user_id, inviter_id, aff_count, aff_quota::double precision, + aff_frozen_quota::double precision, aff_history_quota::double precision, created_at, updated_at @@ -326,6 +421,7 @@ WHERE user_id = $1`, userID) &inviterID, &out.AffCount, &out.AffQuota, + &out.AffFrozenQuota, &out.AffHistoryQuota, &out.CreatedAt, &out.UpdatedAt, @@ -351,6 +447,7 @@ SELECT user_id, inviter_id, aff_count, aff_quota::double precision, + aff_frozen_quota::double precision, aff_history_quota::double precision, created_at, updated_at @@ -380,6 +477,7 @@ LIMIT 1`, strings.ToUpper(strings.TrimSpace(code))) &inviterID, &out.AffCount, &out.AffQuota, + &out.AffFrozenQuota, &out.AffHistoryQuota, &out.CreatedAt, &out.UpdatedAt, diff --git 
a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go index 369f57cf..697a193b 100644 --- a/backend/internal/repository/affiliate_repo_integration_test.go +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -125,7 +125,7 @@ func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { require.NoError(t, err) require.True(t, bound, "invitee must bind to inviter") - applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5) + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0) require.NoError(t, err) require.True(t, applied, "AccrueQuota must report applied=true") diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index b1070aba..f1a42ef7 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -24,6 +24,49 @@ const ( defaultSchedulerSnapshotMGetChunkSize = 128 defaultSchedulerSnapshotWriteChunkSize = 256 + + // snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。 + // 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。 + snapshotGraceTTLSeconds = 60 +) + +var ( + // activateSnapshotScript 原子 CAS 切换快照版本。 + // 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。 + // 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。 + // + // KEYS[1] = activeKey (sched:active:{bucket}) + // KEYS[2] = readyKey (sched:ready:{bucket}) + // KEYS[3] = bucketSetKey (sched:buckets) + // KEYS[4] = snapshotKey (新写入的快照 key) + // ARGV[1] = 新版本号字符串 + // ARGV[2] = bucket 字符串 (用于 SADD) + // ARGV[3] = 快照 key 前缀 (用于构造旧快照 key) + // ARGV[4] = 宽限期 TTL 秒数 + // + // 返回 1 = 已激活, 0 = 版本过旧未激活 + activateSnapshotScript = redis.NewScript(` +local currentActive = redis.call('GET', KEYS[1]) +local newVersion = tonumber(ARGV[1]) + +if currentActive ~= false then + local curVersion = tonumber(currentActive) + if curVersion and newVersion < curVersion then + redis.call('DEL', KEYS[4]) + return 0 + end +end + 
+redis.call('SET', KEYS[1], ARGV[1]) +redis.call('SET', KEYS[2], '1') +redis.call('SADD', KEYS[3], ARGV[2]) + +if currentActive ~= false and currentActive ~= ARGV[1] then + redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4])) +end + +return 1 +`) ) type schedulerCache struct { @@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul } func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error { - activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) - oldActive, _ := c.rdb.Get(ctx, activeKey).Result() - + // Phase 1: 分配新版本号并写入快照数据。 + // INCR 保证每个调用方获得唯一递增版本号。 + // 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。 versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket) version, err := c.rdb.Incr(ctx, versionKey).Result() if err != nil { @@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul return err } - pipe := c.rdb.Pipeline() if len(accounts) > 0 { // 使用序号作为 score,保持数据库返回的排序语义。 members := make([]redis.Z, 0, len(accounts)) @@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul Member: strconv.FormatInt(account.ID, 10), }) } + pipe := c.rdb.Pipeline() for start := 0; start < len(members); start += c.writeChunkSize { end := start + c.writeChunkSize if end > len(members) { @@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul } pipe.ZAdd(ctx, snapshotKey, members[start:end]...) 
} - } else { - pipe.Del(ctx, snapshotKey) - } - pipe.Set(ctx, activeKey, versionStr, 0) - pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0) - pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String()) - if _, err := pipe.Exec(ctx); err != nil { - return err + if _, err := pipe.Exec(ctx); err != nil { + return err + } } - if oldActive != "" && oldActive != versionStr { - _ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err() + // Phase 2: 原子 CAS 激活版本。 + // Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针, + // 防止并发写入导致版本回滚。 + // 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。 + activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) + readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket) + snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode) + + keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey} + args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds} + + _, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result() + if err != nil { + return err } return nil @@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result() } +func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error { + key := schedulerBucketKey(schedulerLockPrefix, bucket) + return c.rdb.Del(ctx, key).Err() +} + func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result() if err != nil { @@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account { SessionWindowStart: account.SessionWindowStart, SessionWindowEnd: account.SessionWindowEnd, SessionWindowStatus: account.SessionWindowStatus, + AccountGroups: filterSchedulerAccountGroups(account.AccountGroups), + GroupIDs: 
filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups), Credentials: filterSchedulerCredentials(account.Credentials), Extra: filterSchedulerExtra(account.Extra), } } +func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup { + if len(accountGroups) == 0 { + return nil + } + + filtered := make([]service.AccountGroup, 0, len(accountGroups)) + for _, ag := range accountGroups { + if ag.GroupID <= 0 { + continue + } + filtered = append(filtered, service.AccountGroup{ + AccountID: ag.AccountID, + GroupID: ag.GroupID, + Priority: ag.Priority, + CreatedAt: ag.CreatedAt, + }) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 { + if len(groupIDs) == 0 && len(accountGroups) == 0 { + return nil + } + + seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups)) + filtered := make([]int64, 0, len(groupIDs)+len(accountGroups)) + for _, id := range groupIDs { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + filtered = append(filtered, id) + } + for _, ag := range accountGroups { + if ag.GroupID <= 0 { + continue + } + if _, ok := seen[ag.GroupID]; ok { + continue + } + seen[ag.GroupID] = struct{}{} + filtered = append(filtered, ag.GroupID) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + func filterSchedulerCredentials(credentials map[string]any) map[string]any { if len(credentials) == 0 { return nil diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go index 134a6a07..948c2c73 100644 --- a/backend/internal/repository/scheduler_cache_integration_test.go +++ b/backend/internal/repository/scheduler_cache_integration_test.go @@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) SessionWindowStart: &now, 
SessionWindowEnd: &windowEnd, SessionWindowStatus: "active", + GroupIDs: []int64{bucket.GroupID}, + AccountGroups: []service.AccountGroup{ + { + AccountID: 101, + GroupID: bucket.GroupID, + Priority: 5, + Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"}, + }, + }, } require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account})) @@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) require.Equal(t, 4, got.GetMaxSessions()) require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes()) require.Nil(t, got.Extra["unused_large_field"]) + require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs) + require.Len(t, got.AccountGroups, 1) + require.Equal(t, account.ID, got.AccountGroups[0].AccountID) + require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID) + require.Nil(t, got.AccountGroups[0].Group) full, err := cache.GetAccount(ctx, account.ID) require.NoError(t, err) require.NotNil(t, full) require.Equal(t, "secret-access-token", full.GetCredential("access_token")) require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob")) + require.Len(t, full.AccountGroups, 1) + require.NotNil(t, full.AccountGroups[0].Group) } diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go index d302c5ea..32dda0a8 100644 --- a/backend/internal/repository/scheduler_cache_unit_test.go +++ b/backend/internal/repository/scheduler_cache_unit_test.go @@ -56,3 +56,43 @@ func TestBuildSchedulerMetadataAccount_KeepsModelRateLimits(t *testing.T) { require.Equal(t, modelLimits, got.Extra["model_rate_limits"], "model_rate_limits must be carried into scheduler snapshot for rate-limit-aware selection") require.Nil(t, got.Extra["unused_large_field"]) } + +func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) { + account := service.Account{ + ID: 42, + Platform: service.PlatformAnthropic, + GroupIDs: []int64{7, 9, 7, 0}, 
+ AccountGroups: []service.AccountGroup{ + { + AccountID: 42, + GroupID: 7, + Priority: 2, + Account: &service.Account{ID: 42, Name: "drop-from-metadata"}, + Group: &service.Group{ID: 7, Name: "drop-from-metadata"}, + }, + { + AccountID: 42, + GroupID: 11, + Priority: 3, + Group: &service.Group{ID: 11, Name: "drop-from-metadata"}, + }, + { + AccountID: 42, + GroupID: 0, + Priority: 4, + }, + }, + } + + got := buildSchedulerMetadataAccount(account) + + require.Equal(t, []int64{7, 9, 11}, got.GroupIDs) + require.Len(t, got.AccountGroups, 2) + require.Equal(t, int64(42), got.AccountGroups[0].AccountID) + require.Equal(t, int64(7), got.AccountGroups[0].GroupID) + require.Equal(t, 2, got.AccountGroups[0].Priority) + require.Nil(t, got.AccountGroups[0].Account) + require.Nil(t, got.AccountGroups[0].Group) + require.Equal(t, int64(11), got.AccountGroups[1].GroupID) + require.Nil(t, got.Groups) +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d605c52b..fabf3b5d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -716,6 +716,9 @@ func TestAPIContracts(t *testing.T) { "default_concurrency": 5, "default_balance": 1.25, "affiliate_rebate_rate": 20, + "affiliate_rebate_freeze_hours": 0, + "affiliate_rebate_duration_days": 0, + "affiliate_rebate_per_invitee_cap": 0, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -737,6 +740,7 @@ func TestAPIContracts(t *testing.T) { "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, "enable_cch_signing": false, + "enable_anthropic_cache_ttl_1h_injection": false, "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, @@ -745,6 +749,16 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": true, "payment_visible_method_wxpay_enabled": false, 
"openai_advanced_scheduler_enabled": true, + "openai_fast_policy_settings": { + "rules": [ + { + "service_tier": "priority", + "action": "filter", + "scope": "all", + "fallback_action": "pass" + } + ] + }, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -898,6 +912,9 @@ func TestAPIContracts(t *testing.T) { "default_concurrency": 0, "default_balance": 0, "affiliate_rebate_rate": 20, + "affiliate_rebate_freeze_hours": 0, + "affiliate_rebate_duration_days": 0, + "affiliate_rebate_per_invitee_cap": 0, "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, @@ -918,12 +935,23 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "enable_cch_signing": false, + "enable_anthropic_cache_ttl_1h_injection": false, "web_search_emulation_enabled": false, "payment_visible_method_alipay_source": "", "payment_visible_method_wxpay_source": "", "payment_visible_method_alipay_enabled": false, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": false, + "openai_fast_policy_settings": { + "rules": [ + { + "service_tier": "priority", + "action": "filter", + "scope": "all", + "fallback_action": "pass" + } + ] + }, "payment_enabled": false, "payment_min_amount": 0, "payment_max_amount": 0, diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index ce2a5dbe..cab0215b 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -66,6 +66,7 @@ func isOpenAIImageModel(model string) bool { type AccountTestService struct { accountRepo AccountRepository geminiTokenProvider *GeminiTokenProvider + claudeTokenProvider *ClaudeTokenProvider antigravityGatewayService *AntigravityGatewayService windsurfChatService *WindsurfChatService httpUpstream HTTPUpstream @@ -77,6 +78,7 @@ type AccountTestService struct { func 
NewAccountTestService( accountRepo AccountRepository, geminiTokenProvider *GeminiTokenProvider, + claudeTokenProvider *ClaudeTokenProvider, antigravityGatewayService *AntigravityGatewayService, windsurfChatService *WindsurfChatService, httpUpstream HTTPUpstream, @@ -86,6 +88,7 @@ func NewAccountTestService( return &AccountTestService{ accountRepo: accountRepo, geminiTokenProvider: geminiTokenProvider, + claudeTokenProvider: claudeTokenProvider, antigravityGatewayService: antigravityGatewayService, windsurfChatService: windsurfChatService, httpUpstream: httpUpstream, @@ -223,6 +226,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if account.IsBedrock() { return s.testBedrockAccountConnection(c, ctx, account, testModelID, prompt) } + if account.Type == AccountTypeServiceAccount { + return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID) + } // Determine authentication method and API URL var authToken string @@ -326,6 +332,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.processClaudeStream(c, resp.Body) } +func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { + if mappedModel, matched := account.ResolveMappedModel(testModelID); matched { + testModelID = mappedModel + } else { + testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID)) + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + payload, err := createTestPayload(testModelID, "") + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create test payload") + } + payloadBytes, _ := json.Marshal(payload) + vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes) + if err != nil { + 
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error())) + } + + if s.claudeTokenProvider == nil { + return s.sendErrorAndEnd(c, "Claude token provider not configured") + } + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error())) + } + + fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error())) + } + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)) + if resp.StatusCode == http.StatusForbidden { + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } + return s.sendErrorAndEnd(c, errMsg) + } + + return s.processClaudeStream(c, resp.Body) +} + // testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string, 
prompt string) error { region := bedrockRuntimeRegion(account) @@ -728,8 +802,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account testModelID = geminicli.DefaultTestModel } - // For API Key accounts with model mapping, map the model - if account.Type == AccountTypeAPIKey { + // For static upstream credentials with model mapping, map the model + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mapping := account.GetModelMapping() if len(mapping) > 0 { if mappedModel, exists := mapping[testModelID]; exists { @@ -757,6 +831,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) case AccountTypeOAuth: req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) + case AccountTypeServiceAccount: + req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload) default: return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -948,6 +1024,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload) } +func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { + if s.geminiTokenProvider == nil { + return nil, fmt.Errorf("gemini token provider not configured") + } + accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("failed to get service account access token: %w", err) + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, 
err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + return req, nil +} + // buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity) func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) { var inner map[string]any @@ -1286,7 +1383,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations" + apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint) // Set SSE headers c.Writer.Header().Set("Content-Type", "text/event-stream") diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go index 80a2fc31..257159c4 100644 --- a/backend/internal/service/account_test_service_openai_image_test.go +++ b/backend/internal/service/account_test_service_openai_image_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -48,3 +49,42 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") require.Contains(t, rec.Body.String(), "\"success\":true") } + +func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": 
[]string{"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{}, + } + account := &Account{ + ID: 54, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat") + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") + require.Contains(t, rec.Body.String(), "\"success\":true") +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 147ee4e1..b854c16e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -9,6 +9,7 @@ import ( "log/slog" "net/http" "sort" + "strconv" "strings" "time" @@ -58,6 +59,7 @@ type AdminService interface { // API Key management (admin) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) + AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) // ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限 ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) @@ -291,6 +293,7 @@ type UpdateAccountInput struct { // BulkUpdateAccountsInput describes the payload for bulk updating accounts. 
type BulkUpdateAccountsInput struct { AccountIDs []int64 + Filters *BulkUpdateAccountFilters Name string ProxyID *int64 Concurrency *int @@ -307,6 +310,15 @@ type BulkUpdateAccountsInput struct { SkipMixedChannelCheck bool } +type BulkUpdateAccountFilters struct { + Platform string + Type string + Status string + Group string + Search string + PrivacyMode string +} + // BulkUpdateAccountResult captures the result for a single account update. type BulkUpdateAccountResult struct { AccountID int64 `json:"account_id"` @@ -1961,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i return result, nil } +// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows. +func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("reset api key rate limit usage: %w", err) + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } + return apiKey, nil +} + // ReplaceUserGroup 替换用户的专属分组 func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { if oldGroupID == newGroupID { @@ -2286,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U // BulkUpdateAccounts updates multiple accounts in one request. // It merges credentials/extra keys instead of overwriting the whole object. 
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { + if len(input.AccountIDs) == 0 && input.Filters != nil { + accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters) + if err != nil { + return nil, err + } + input.AccountIDs = accountIDs + } + result := &BulkUpdateAccountsResult{ SuccessIDs: make([]int64, 0, len(input.AccountIDs)), FailedIDs: make([]int64, 0, len(input.AccountIDs)), @@ -2401,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp return result, nil } +func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) { + if filters == nil { + return nil, nil + } + + groupID := int64(0) + switch strings.TrimSpace(filters.Group) { + case "": + case "ungrouped": + groupID = AccountListGroupUngrouped + default: + parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid group filter: %w", err) + } + groupID = parsedGroupID + } + + const pageSize = 500 + page := 1 + accountIDs := make([]int64, 0, pageSize) + + for { + accounts, total, err := s.ListAccounts( + ctx, + page, + pageSize, + filters.Platform, + filters.Type, + filters.Status, + filters.Search, + groupID, + filters.PrivacyMode, + "", + "", + ) + if err != nil { + return nil, err + } + for _, account := range accounts { + accountIDs = append(accountIDs, account.ID) + } + if int64(len(accountIDs)) >= total || len(accounts) == 0 { + return accountIDs, nil + } + page++ + } +} + func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { if err := s.accountRepo.Delete(ctx, id); err != nil { return err diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 4845d87c..df415295 100644 --- 
a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -5,8 +5,10 @@ package service import ( "context" "errors" + "reflect" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" ) @@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct { getByIDCalled []int64 listByGroupData map[int64][]Account listByGroupErr map[int64]error + listData []Account + listResult *pagination.PaginationResult + listErr error + listCalled bool + lastListParams pagination.PaginationParams + lastListFilters struct { + platform string + accountType string + status string + search string + groupID int64 + privacyMode string + } } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in return nil, nil } +func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { + s.listCalled = true + s.lastListParams = params + s.lastListFilters.platform = platform + s.lastListFilters.accountType = accountType + s.lastListFilters.status = status + s.lastListFilters.search = search + s.lastListFilters.groupID = groupID + s.lastListFilters.privacyMode = privacyMode + if s.listErr != nil { + return nil, nil, s.listErr + } + if s.listResult != nil { + return s.listData, s.listResult, nil + } + return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -170,3 +203,46 @@ func 
TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon // No BindGroups should have been called since the check runs before any write. require.Empty(t, repo.bindGroupsCalls) } + +func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + listData: []Account{ + {ID: 7}, + {ID: 11}, + }, + listResult: &pagination.PaginationResult{Total: 2}, + } + svc := &adminServiceImpl{accountRepo: repo} + + schedulable := true + input := &BulkUpdateAccountsInput{ + Schedulable: &schedulable, + } + + filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters") + require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update") + require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field") + + filtersValue := reflect.New(filtersField.Type().Elem()) + filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI) + filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth) + filtersValue.Elem().FieldByName("Status").SetString(StatusActive) + filtersValue.Elem().FieldByName("Group").SetString("12") + filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked) + filtersValue.Elem().FieldByName("Search").SetString("bulk-target") + filtersField.Set(filtersValue) + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters") + require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform) + require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType) + require.Equal(t, StatusActive, repo.lastListFilters.status) + require.Equal(t, "bulk-target", repo.lastListFilters.search) + require.Equal(t, int64(12), repo.lastListFilters.groupID) + require.Equal(t, PrivacyModeCFBlocked, 
repo.lastListFilters.privacyMode) + require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs) + require.Equal(t, 2, result.Success) + require.Equal(t, 0, result.Failed) + require.Equal(t, []int64{7, 11}, result.SuccessIDs) +} diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index aca32076..5a4e91e7 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -65,16 +65,18 @@ type AffiliateSummary struct { InviterID *int64 `json:"inviter_id,omitempty"` AffCount int `json:"aff_count"` AffQuota float64 `json:"aff_quota"` + AffFrozenQuota float64 `json:"aff_frozen_quota"` AffHistoryQuota float64 `json:"aff_history_quota"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } type AffiliateInvitee struct { - UserID int64 `json:"user_id"` - Email string `json:"email"` - Username string `json:"username"` - CreatedAt *time.Time `json:"created_at,omitempty"` + UserID int64 `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` + CreatedAt *time.Time `json:"created_at,omitempty"` + TotalRebate float64 `json:"total_rebate"` } type AffiliateDetail struct { @@ -83,6 +85,7 @@ type AffiliateDetail struct { InviterID *int64 `json:"inviter_id,omitempty"` AffCount int `json:"aff_count"` AffQuota float64 `json:"aff_quota"` + AffFrozenQuota float64 `json:"aff_frozen_quota"` AffHistoryQuota float64 `json:"aff_history_quota"` // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例: // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。 @@ -95,7 +98,9 @@ type AffiliateRepository interface { EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) - AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) + AccrueQuota(ctx 
context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) + GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) + ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error) @@ -160,6 +165,12 @@ func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64 } func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) { + // Lazy thaw: move any matured frozen quota to available before reading. + if s != nil && s.repo != nil { + // best-effort: thaw failure is non-fatal + _, _ = s.repo.ThawFrozenQuota(ctx, userID) + } + summary, err := s.EnsureUserAffiliate(ctx, userID) if err != nil { return nil, err @@ -174,6 +185,7 @@ func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) InviterID: summary.InviterID, AffCount: summary.AffCount, AffQuota: summary.AffQuota, + AffFrozenQuota: summary.AffFrozenQuota, AffHistoryQuota: summary.AffHistoryQuota, EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary), Invitees: invitees, @@ -250,13 +262,43 @@ func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID if err != nil { return 0, err } + // Validity-window check: once durationDays have passed since inviteeSummary.CreatedAt, no further rebate accrues (a setting of 0 or less disables the check). + if s.settingService != nil { + if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 { + if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) { + return 0, nil + } + } + } + rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary) rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8) if rebate <= 0 { return 0, nil } - applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate) + // Per-invitee cap check: clamp the rebate to the allowance still remaining under the cap. NOTE(review): the accrued total is read and the rebate applied in separate repository calls — confirm the cap holds under concurrent recharges. + if s.settingService != 
nil { + if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 { + existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID) + if err != nil { + return 0, err + } + if existing >= perInviteeCap { + return 0, nil + } + if remaining := perInviteeCap - existing; rebate > remaining { + rebate = roundTo(remaining, 8) + } + } + } + + var freezeHours int + if s.settingService != nil { + freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx) + } + + applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours) if err != nil { return 0, err } diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index a18cf39c..9815f31b 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -175,6 +175,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( user *User, invitationCode string, signupSource string, + affiliateCode string, ) error { if s == nil || user == nil || user.ID <= 0 { return ErrServiceUnavailable @@ -194,6 +195,7 @@ func (s *AuthService) FinalizeOAuthEmailAccount( s.updateOAuthSignupSource(ctx, user.ID, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) return nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 08b0f4b7..b1adf071 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -563,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 // invitationCode 
仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) { +// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -666,6 +667,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -683,6 +685,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -777,6 +780,22 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource } } +// bindOAuthAffiliate initializes the affiliate profile and binds the inviter +// for an OAuth-registered user. Failures are logged but never block registration. 
+func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) { + if s.affiliateService == nil || userID <= 0 { + return + } + if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err) + } + if code := strings.TrimSpace(affiliateCode); code != "" { + if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err) + } + } +} + func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { if user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index c1ad6240..acc44a38 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "") require.NoError(t, err) require.NotNil(t, tokenPair) require.NotNil(t, user) @@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + tokenPair, user, err := 
service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "") require.NoError(t, err) require.NotNil(t, tokenPair) require.Equal(t, existing.ID, user.ID) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 4e695eb9..050db55b 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -508,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } +// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key. +func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + if s.cache == nil { + return nil + } + if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err) + return err + } + return nil +} + // ============================================ // API Key 限速缓存方法 // ============================================ diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index 82fa31c4..d70379c1 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -17,7 +17,7 @@ const ( // ClaudeTokenCache token cache interface. type ClaudeTokenCache = GeminiTokenCache -// ClaudeTokenProvider manages access_token for Claude OAuth accounts. +// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts. 
type ClaudeTokenProvider struct { accountRepo AccountRepository tokenCache ClaudeTokenCache @@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { - return "", errors.New("not an anthropic oauth account") + if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) { + return "", errors.New("not an anthropic oauth or service account") + } + if account.Type == AccountTypeServiceAccount { + return p.getServiceAccountAccessToken(ctx, account) } cacheKey := ClaudeTokenCacheKey(account) @@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { + return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account) +} diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go index 3e21f6f4..d4a4a14a 100644 --- a/backend/internal/service/claude_token_provider_test.go +++ b/backend/internal/service/claude_token_provider_test.go @@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A return "", errors.New("account is nil") } if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth { - return "", errors.New("not an anthropic oauth account") + return "", errors.New("not an anthropic oauth or service account") } cacheKey := ClaudeTokenCacheKey(account) @@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Contains(t, 
err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } @@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Contains(t, err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } @@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an anthropic oauth account") + require.Contains(t, err.Error(), "not an anthropic oauth or service account") require.Empty(t, token) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 1c8e7cc9..9793b6b2 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -20,10 +20,15 @@ const ( // Affiliate rebate settings const ( - AffiliateRebateRateDefault = 20.0 - AffiliateRebateRateMin = 0.0 - AffiliateRebateRateMax = 100.0 - AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 + AffiliateRebateRateDefault = 20.0 + AffiliateRebateRateMin = 0.0 + AffiliateRebateRateMax = 100.0 + AffiliateEnabledDefault = false // 邀请返利总开关默认关闭 + AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容) + AffiliateRebateFreezeHoursMax = 720 // 最大 30 天 + AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效 + AffiliateRebateDurationDaysMax = 3650 // ~10 年 + AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限 ) // Platform constants @@ -37,11 +42,12 @@ const ( // Account type constants const ( - AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 - AccountTypeUpstream 
= domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) - AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) + AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI) ) // Redeem type constants @@ -98,6 +104,9 @@ const ( SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关 SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100) + SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结) + SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久) + SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -299,6 +308,12 @@ const ( // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. SettingKeyBetaPolicySettings = "beta_policy_settings" + // SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI + // service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but + // targets OpenAI's body-level service_tier field instead of Claude's + // anthropic-beta header. 
+ SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings" + // ========================= // Claude Code Version Check // ========================= @@ -322,6 +337,8 @@ const ( SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" // SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false) SettingKeyEnableCCHSigning = "enable_cch_signing" + // SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false) + SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection" // Balance Low Notification SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 diff --git a/backend/internal/service/gateway_anthropic_vertex_service_account_test.go b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go new file mode 100644 index 00000000..aa779805 --- /dev/null +++ b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go @@ -0,0 +1,68 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("Anthropic-Version", "2023-06-01") + c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + + account := &Account{ + ID: 301, + Platform: PlatformAnthropic, + Type: AccountTypeServiceAccount, + Credentials: map[string]any{ + "project_id": "vertex-proj", + "location": "us-east5", + }, + } + body := 
[]byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`) + + svc := &GatewayService{} + req, err := svc.buildUpstreamRequest( + context.Background(), + c, + account, + body, + "vertex-token", + "service_account", + "claude-sonnet-4-5@20250929", + false, + false, + ) + require.NoError(t, err) + require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String()) + require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization")) + require.Empty(t, getHeaderRaw(req.Header, "x-api-key")) + require.Empty(t, getHeaderRaw(req.Header, "anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta")) + + got := readRequestBodyForTest(t, req) + require.Equal(t, "", gjson.GetBytes(got, "model").String()) + require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String()) + require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String()) +} + +func readRequestBodyForTest(t *testing.T, req *http.Request) []byte { + t.Helper() + require.NotNil(t, req.Body) + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + return body +} diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go index e6c9de7d..e0c3cafd 100644 --- a/backend/internal/service/gateway_body_order_test.go +++ b/backend/internal/service/gateway_body_order_test.go @@ -1,13 +1,91 @@ package service import ( + "context" + "errors" "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) +type gatewayTTLSettingRepo struct { + data map[string]string +} + +func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) { + 
return nil, ErrSettingNotFound +} + +func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) { + if r == nil { + return "", ErrSettingNotFound + } + v, ok := r.data[key] + if !ok { + return "", ErrSettingNotFound + } + return v, nil +} + +func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error { + if r == nil { + return errors.New("setting repo is nil") + } + if r.data == nil { + r.data = map[string]string{} + } + r.data[key] = value + return nil +} + +func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + if r == nil { + return result, nil + } + for _, key := range keys { + if v, ok := r.data[key]; ok { + result[key] = v + } + } + return result, nil +} + +func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error { + if r == nil { + return errors.New("setting repo is nil") + } + if r.data == nil { + r.data = map[string]string{} + } + for key, value := range settings { + r.data[key] = value + } + return nil +} + +func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string) + if r == nil { + return result, nil + } + for key, value := range r.data { + result[key] = value + } + return result, nil +} + +func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error { + if r != nil { + delete(r.data, key) + } + return nil +} + func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) { t.Helper() @@ -71,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) { assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`) require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`)) } + +func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) { + body := 
[]byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`) + + result := injectAnthropicCacheControlTTL1h(body) + resultStr := string(result) + + assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`) + require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String()) + require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) + require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists()) + require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) + require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String()) + require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String()) +} + +func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) { + repo := &gatewayTTLSettingRepo{data: map[string]string{ + SettingKeyEnableAnthropicCacheTTL1hInjection: "true", + }} + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + svc := &GatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } + account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth} + + target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account) + require.True(t, ok) + require.Equal(t, cacheTTLTarget5m, target) + + account.Extra = map[string]any{ + "cache_ttl_override_enabled": true, + "cache_ttl_override_target": "1h", + } + target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account) + require.True(t, ok) + 
require.Equal(t, cacheTTLTarget1h, target) +} + +func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) { + repo := &gatewayTTLSettingRepo{data: map[string]string{ + SettingKeyEnableAnthropicCacheTTL1hInjection: "true", + }} + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + svc := &GatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } + + require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth})) + require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken})) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey})) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth})) + + repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false" + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{}) + require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth})) +} diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index c531667e..7ac77f77 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions( // 4. 
Model mapping mappedModel := originalModel - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } - if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel)) + if normalized != originalModel { + mappedModel = normalized + } + } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(originalModel) if normalized != originalModel { mappedModel = normalized diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 647193d6..8f8a1e94 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses( // 4. 
Model mapping mappedModel := originalModel reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } - if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel)) + if normalized != originalModel { + mappedModel = normalized + } + } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(originalModel) if normalized != originalModel { mappedModel = normalized diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 701e56b2..319d4d0c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -11,6 +11,7 @@ import ( "io" "log/slog" mathrand "math/rand" + "net" "net/http" "net/url" "os" @@ -20,6 +21,7 @@ import ( "strconv" "strings" "sync/atomic" + "syscall" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -63,6 +65,11 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +const ( + cacheTTLTarget5m = "5m" + cacheTTLTarget1h = "1h" +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} @@ -335,9 +342,8 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). 
var ( - sseDataRe = regexp.MustCompile(`^data:\s*`) - claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) - claudeCodeUserAgentRe = regexp.MustCompile(`^claude-(?:cli|code)/\d+\.\d+\.\d+`) + sseDataRe = regexp.MustCompile(`^data:\s*`) + claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -677,15 +683,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // 1. 最高优先级:从 metadata.user_id 提取 session_xxx if parsed.MetadataUserID != "" { - if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" { + uid := ParseMetadataUserID(parsed.MetadataUserID) + if uid != nil && uid.SessionID != "" { + slog.Info("sticky.hash_source", + "source", "metadata_user_id", + "session_id", uid.SessionID, + "device_id", uid.DeviceID, + "is_new_format", uid.IsNewFormat, + ) return uid.SessionID } + slog.Info("sticky.hash_metadata_parse_failed", + "metadata_user_id", parsed.MetadataUserID, + "parsed_nil", uid == nil, + ) } // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 cacheableContent := s.extractCacheableContent(parsed) if cacheableContent != "" { - return s.hashContent(cacheableContent) + hash := s.hashContent(cacheableContent) + slog.Info("sticky.hash_source", + "source", "cacheable_content", + "hash", hash, + ) + return hash } // 3. 
最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 @@ -725,7 +747,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } } if combined.Len() > 0 { - return s.hashContent(combined.String()) + hash := s.hashContent(combined.String()) + slog.Info("sticky.hash_source", + "source", "message_content_fallback", + "hash", hash, + "content_len", combined.Len(), + ) + return hash } return "" @@ -1432,14 +1460,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } var stickyAccountID int64 + var stickySource string if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch + stickySource = "prefetch" } else if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID + stickySource = "cache" } } + // [DEBUG-STICKY] 调度器入口日志 + slog.Info("sticky.scheduler_entry", + "group_id", derefGroupID(groupID), + "session_hash", shortSessionHash(sessionHash), + "sticky_account_id", stickyAccountID, + "sticky_source", stickySource, + "model", requestedModel, + "load_batch", cfg.LoadBatchEnabled, + "has_concurrency_svc", s.concurrencyService != nil, + "excluded_count", len(excludedIDs), + ) + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1615,6 +1658,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(routingCandidates) > 0 { // 1.5. 
在路由账号范围内检查粘性会话 if sessionHash != "" && stickyAccountID > 0 { + slog.Debug("sticky.layer1_5_checking", + "sticky_account_id", stickyAccountID, + "in_routing_list", containsInt64(routingAccountIDs, stickyAccountID), + "is_excluded", isExcluded(stickyAccountID), + "in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(), + "session", shortSessionHash(sessionHash), + ) if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { @@ -1638,6 +1688,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { + slog.Debug("sticky.layer1_5_hit", + "account_id", stickyAccountID, + "session", shortSessionHash(sessionHash), + "result", "slot_acquired", + ) if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } @@ -1788,27 +1843,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 检查账户是否需要清理粘性会话绑定 clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { + slog.Debug("sticky.layer1_5_no_routing_clear", + "account_id", accountID, + "reason", "should_clear_sticky_session", + "session", shortSessionHash(sessionHash), + ) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && - s.isAccountAllowedForPlatform(account, platform, useMixed) && - (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && - s.isAccountSchedulableForQuota(account) && - s.isAccountSchedulableForWindowCost(ctx, account, true) && - 
s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 + // 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的 + // accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存 + // 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。 + platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed) + modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) + modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) + quotaOK := s.isAccountSchedulableForQuota(account) + windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true) + rpmOK := s.isAccountSchedulableForRPM(ctx, account, true) + schedulable := s.isAccountSchedulableForSelection(account) + + slog.Debug("sticky.layer1_5_no_routing_checks", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "clear_sticky", clearSticky, + "schedulable", schedulable, + "platform_ok", platformOK, + "model_supported", modelSupported, + "model_schedulable", modelSchedulable, + "quota_ok", quotaOK, + "window_cost_ok", windowCostOK, + "rpm_ok", rpmOK, + ) + + if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", accountID, + "reason", "session_limit", + "session", shortSessionHash(sessionHash), + ) } else { + slog.Debug("sticky.layer1_5_no_routing_hit", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "result", "slot_acquired", + ) if s.cache != nil { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) } return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } + } else { + 
slog.Debug("sticky.layer1_5_no_routing_slot_busy", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + ) } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) @@ -1817,6 +1910,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 } else { + slog.Debug("sticky.layer1_5_no_routing_hit", + "account_id", accountID, + "session", shortSessionHash(sessionHash), + "result", "wait_plan", + ) return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ AccountID: accountID, MaxConcurrency: account.Concurrency, @@ -1825,12 +1923,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro }) } } + } else if !clearSticky { + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", accountID, + "reason", "gate_check_failed", + "session", shortSessionHash(sessionHash), + ) } + } else { + slog.Debug("sticky.layer1_5_no_routing_miss", + "account_id", accountID, + "reason", "account_not_in_map", + "session", shortSessionHash(sessionHash), + ) } } + } else if len(routingAccountIDs) == 0 && sessionHash != "" { + slog.Debug("sticky.layer1_5_no_routing_skip", + "sticky_account_id", stickyAccountID, + "is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(), + "session", shortSessionHash(sessionHash), + "reason", func() string { + if stickyAccountID == 0 { + return "no_sticky_binding" + } + return "sticky_account_excluded" + }(), + ) } // ============ Layer 2: 负载感知选择 ============ + slog.Debug("sticky.layer2_fallback", + "session", shortSessionHash(sessionHash), + "sticky_account_id", stickyAccountID, + "reason", "sticky_not_used_falling_back_to_load_balance", + "total_accounts", len(accounts), + ) candidates := make([]*Account, 0, len(accounts)) for i := range accounts { acc := &accounts[i] @@ -3654,7 +3782,11 @@ func (s *GatewayService) 
isModelSupportedByAccount(account *Account, requestedMo } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { - requestedModel = claude.NormalizeModelID(requestedModel) + if account.Type == AccountTypeServiceAccount { + requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel)) + } else { + requestedModel = claude.NormalizeModelID(requestedModel) + } } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) @@ -3674,6 +3806,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( return apiKey, "apikey", nil case AccountTypeBedrock: return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 + case AccountTypeServiceAccount: + if account.Platform != PlatformAnthropic { + return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform) + } + if s.claudeTokenProvider == nil { + return "", "", errors.New("claude token provider not configured") + } + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "service_account", nil default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -3781,16 +3925,6 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { return ParseMetadataUserID(metadataUserID) != nil } -func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { - if IsClaudeCodeClient(ctx) { - return true - } - if parsed == nil || c == nil { - return false - } - return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) -} - // normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), // 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。 // 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 @@ -4153,6 +4287,87 @@ func 
enforceCacheControlLimit(body []byte) []byte { return body } +// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。 +// 仅修改已经存在的 cache_control,不新增缓存断点。 +func injectAnthropicCacheControlTTL1h(body []byte) []byte { + return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h) +} + +func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte { + if len(body) == 0 || ttl == "" { + return body + } + out := body + var paths []string + addPath := func(path string, value gjson.Result) { + cc := value.Get("cache_control") + if !cc.Exists() || cc.Get("type").String() != "ephemeral" { + return + } + if cc.Get("ttl").String() == ttl { + return + } + paths = append(paths, path+".cache_control.ttl") + } + + if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl { + paths = append(paths, "cache_control.ttl") + } + + system := gjson.GetBytes(body, "system") + if system.IsArray() { + idx := -1 + system.ForEach(func(_, block gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("system.%d", idx), block) + return true + }) + } + + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIdx := -1 + messages.ForEach(func(_, msg gjson.Result) bool { + msgIdx++ + content := msg.Get("content") + if !content.IsArray() { + return true + } + contentIdx := -1 + content.ForEach(func(_, block gjson.Result) bool { + contentIdx++ + addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block) + return true + }) + return true + }) + } + + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + idx := -1 + tools.ForEach(func(_, tool gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("tools.%d", idx), tool) + return true + }) + } + + for _, path := range paths { + if next, err := sjson.SetBytes(out, path, ttl); err == nil { + out = next + } + } + return out +} + +func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, 
account *Account) bool { + if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil { + return false + } + return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() @@ -4286,6 +4501,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + if candidate, matched := account.ResolveMappedModel(reqModel); matched { + mappedModel = candidate + mappingSource = "account" + } else { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel)) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "vertex" + } + } + } if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(reqModel) if normalized != reqModel { @@ -4300,6 +4527,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } + if s.shouldInjectAnthropicCacheTTL1h(ctx, account) { + body = injectAnthropicCacheControlTTL1h(body) + } + // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { @@ -5799,6 +6030,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse( } func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { + if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + 
return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream) + } + // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { @@ -6010,6 +6245,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } +func (s *GatewayService) buildUpstreamRequestAnthropicVertex( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, + modelID string, + reqStream bool, +) (*http.Request, error) { + vertexBody, err := buildVertexAnthropicRequestBody(body) + if err != nil { + return nil, err + } + setOpsUpstreamRequestBody(c, vertexBody) + fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" { + continue + } + wireKey := resolveWireCasing(key) + for _, v := range values { + addHeaderRaw(req.Header, wireKey, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Del("anthropic-version") + setHeaderRaw(req.Header, "authorization", "Bearer "+token) + setHeaderRaw(req.Header, "content-type", "application/json") + + s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{ + "url": req.URL.String(), + "token_type": "service_account", + "model": modelID, + "stream": strconv.FormatBool(reqStream), + }) + + return req, nil +} + // getBetaHeader 处理anthropic-beta header // 对于OAuth账号,需要确保包含oauth-2025-04-20 func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) 
string { @@ -6567,6 +6856,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { return false } +// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。 +// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游 +// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection +// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。 +func sanitizeStreamError(err error) string { + if err == nil { + return "" + } + switch { + case errors.Is(err, io.ErrUnexpectedEOF): + return "unexpected EOF" + case errors.Is(err, io.EOF): + return "EOF" + case errors.Is(err, context.Canceled): + return "canceled" + case errors.Is(err, context.DeadlineExceeded): + return "deadline exceeded" + case errors.Is(err, syscall.ECONNRESET): + return "connection reset by peer" + case errors.Is(err, syscall.ECONNABORTED): + return "connection aborted" + case errors.Is(err, syscall.ETIMEDOUT): + return "connection timed out" + case errors.Is(err, syscall.EPIPE): + return "broken pipe" + case errors.Is(err, syscall.ECONNREFUSED): + return "connection refused" + } + var netErr *net.OpError + if errors.As(err, &netErr) { + if netErr.Timeout() { + if netErr.Op != "" { + return netErr.Op + " timeout" + } + return "i/o timeout" + } + if netErr.Op != "" { + return netErr.Op + " network error" + } + } + return "upstream connection error" +} + // ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 // 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} func ExtractUpstreamErrorMessage(body []byte) string { @@ -7004,14 +7336,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } lastDataAt := time.Now() - // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。 + // 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":,"message":}} + // 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案, + // 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。 errorEventSent := false - sendErrorEvent := func(reason 
string) { + sendErrorEvent := func(reason, message string) { if errorEventSent { return } errorEventSent = true - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + if message == "" { + message = reason + } + body, err := json.Marshal(map[string]any{ + "type": "error", + "error": map[string]string{ + "type": reason, + "message": message, + }, + }) + if err != nil { + // json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback + body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message)) + } + _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body) flusher.Flush() } @@ -7088,9 +7437,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } - // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { - overrideTarget := account.GetCacheTTLOverrideTarget() + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { @@ -7171,10 +7520,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") + sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize)) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } - sendErrorEvent("stream_read_error") + // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY): + // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。 + // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。 + // 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 
剥离地址, + // 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err + // 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。 + disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err) + if !c.Writer.Written() { + logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err) + body, _ := json.Marshal(map[string]any{ + "type": "error", + "error": map[string]string{ + "type": "upstream_disconnected", + "message": disconnectMsg, + }, + }) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: body, + RetryableOnSameAccount: true, + } + } + sendErrorEvent("stream_read_error", disconnectMsg) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line @@ -7233,7 +7604,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } - sendErrorEvent("stream_timeout") + sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval)) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: @@ -7475,6 +7846,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { return true } +func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) { + if account == nil { + return "", false + } + if account.IsCacheTTLOverrideEnabled() { + return account.GetCacheTTLOverrideTarget(), true + } + if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) { + return cacheTTLTarget5m, true + } + return "", false +} + func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, 
account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -7511,9 +7895,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } - // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { - overrideTarget := account.GetCacheTTLOverrideTarget() + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { @@ -8081,10 +8465,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage result.Usage.InputTokens = 0 } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + // Cache TTL Override: 确保计费时 token 分类与账号设置一致。 + // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { - applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { + applyCacheTTLOverride(&result.Usage, overrideTarget) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index b1584827..ef09a882 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -4,9 +4,12 @@ package service import ( "context" + "errors" "io" + "net" "net/http" "net/http/httptest" + "syscall" "testing" "time" @@ -218,3 +221,175 @@ func 
TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { body := rec.Body.String() require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") } + +// 上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前: +// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。 +func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{err: io.ErrUnexpectedEOF}, + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + + require.Error(t, err) + require.Nil(t, result, "失败移交场景下不应返回 streamingResult") + + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试") + + // ResponseBody 必须是 Anthropic 标准 error 格式: + // 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖) + // 2) error.type 标记为 upstream_disconnected + extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody) + require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息") + require.Contains(t, extractedMsg, "upstream stream disconnected") + require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`) + require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`) + + // 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定 + require.NotContains(t, 
rec.Body.String(), "stream_read_error") +} + +// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误: +// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。 +func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"), + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error") + require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult") + + // 不应被错误地包成 failover error + var failoverErr *UpstreamFailoverError + require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover") + + // 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error, + // error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误) + body := rec.Body.String() + require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧") + require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)") + require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error") + require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案") +} + +// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游 +// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过 +// failover ResponseBody 或 SSE error 帧返回给客户端。 +func TestSanitizeStreamError_StripsNetworkAddresses(t 
*testing.T) { + src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321") + require.NoError(t, err) + dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443") + require.NoError(t, err) + + raw := &net.OpError{ + Op: "read", + Net: "tcp", + Source: src, + Addr: dst, + Err: syscall.ECONNRESET, + } + + // 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过) + require.Contains(t, raw.Error(), "10.0.0.1") + require.Contains(t, raw.Error(), "52.1.2.3") + + got := sanitizeStreamError(raw) + require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP") + require.NotContains(t, got, "54321", "不得泄露源端口") + require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP") + require.NotContains(t, got, "443", "不得泄露上游端口") + require.Equal(t, "connection reset by peer", got) +} + +func TestSanitizeStreamError_KnownErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"}, + {"EOF", io.EOF, "EOF"}, + {"context canceled", context.Canceled, "canceled"}, + {"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"}, + {"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"}, + {"EPIPE", syscall.EPIPE, "broken pipe"}, + {"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"}, + {"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"}, + {"nil 返回空串", nil, ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.want, sanitizeStreamError(tc.err)) + }) + } +} + +// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志 +// 时携带内部地址信息。 +func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321") + dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443") + netErr := 
&net.OpError{ + Op: "read", + Net: "tcp", + Source: src, + Addr: dst, + Err: syscall.ECONNRESET, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{err: netErr}, + } + + _, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr)) + + body := string(failoverErr.ResponseBody) + require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP") + require.NotContains(t, body, "54321") + require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP") + require.NotContains(t, body, "443") + // 仍然包含可诊断的根因 + require.Contains(t, body, "connection reset by peer") + require.Contains(t, body, "upstream stream disconnected") +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 7a24071b..ea0c0d7d 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont } // Code Assist OAuth tokens often lack AI Studio scopes for models listing. return 3 + case AccountTypeServiceAccount: + // Vertex service accounts use aiplatform.googleapis.com, not the AI Studio + // endpoint (generativelanguage.googleapis.com), so they cannot serve these requests. 
+ return 999 default: return 10 } @@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex originalModel := req.Model mappedModel := req.Model - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(req.Model) } @@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } requestIDHeader = "x-request-id" + case AccountTypeServiceAccount: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream) + if err != nil { + return nil, "", err + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + default: return nil, fmt.Errorf("unsupported account type: %s", account.Type) } @@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
body = ensureGeminiFunctionCallThoughtSignatures(body) mappedModel := originalModel - if account.Type == AccountTypeAPIKey { + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { mappedModel = account.GetMappedModel(originalModel) } @@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. } requestIDHeader = "x-request-id" + case AccountTypeServiceAccount: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream) + if err != nil { + return nil, "", err + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + default: return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type) } diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 7add3460..172b9411 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -15,7 +15,7 @@ const ( geminiTokenCacheSkew = 5 * time.Minute ) -// GeminiTokenProvider manages access_token for Gemini OAuth accounts. +// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts. 
type GeminiTokenProvider struct { accountRepo AccountRepository tokenCache GeminiTokenCache @@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth { - return "", errors.New("not a gemini oauth account") + if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) { + return "", errors.New("not a gemini oauth or service account") + } + if account.Type == AccountTypeServiceAccount { + return p.getServiceAccountAccessToken(ctx, account) } cacheKey := GeminiTokenCacheKey(account) @@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } +func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) { + return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account) +} + func GeminiTokenCacheKey(account *Account) string { + if account != nil && account.Type == AccountTypeServiceAccount { + if key, err := parseVertexServiceAccountKey(account); err == nil { + return vertexServiceAccountCacheKey(account, key) + } + } projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { return "gemini:" + projectID diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index e765d7e9..b256f1c7 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -53,6 +53,23 @@ const ( codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. 
If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n" ) +var openAIChatGPTInternalUnsupportedFields = []string{ + "user", + "metadata", + "prompt_cache_retention", + "safety_identifier", + "stream_options", +} + +var openAICodexOAuthUnsupportedFields = append([]string{ + "max_output_tokens", + "max_completion_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", +}, openAIChatGPTInternalUnsupportedFields...) + func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -93,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } } - // Strip parameters unsupported by codex models via the Responses API. - for _, key := range []string{ - "max_output_tokens", - "max_completion_tokens", - "temperature", - "top_p", - "frequency_penalty", - "presence_penalty", - // prompt_cache_retention is a newer Responses API parameter (cache TTL). - // The ChatGPT internal Codex endpoint rejects it with - // "Unsupported parameter: prompt_cache_retention". Defense-in-depth - // for any OAuth path that reaches this transform — the Cursor - // Responses-shape short-circuit in ForwardAsChatCompletions strips - // it earlier too, but we keep this line so other OAuth callers are - // equally protected. - "prompt_cache_retention", - } { + // Strip parameters unsupported by ChatGPT internal Codex endpoint. 
+ for _, key := range openAICodexOAuthUnsupportedFields { if _, ok := reqBody[key]; ok { delete(reqBody, key) result.Modified = true @@ -141,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" { reqBody["tool_choice"] = map[string]any{ "type": "function", - "function": map[string]any{ - "name": name, - }, + "name": name, } } } @@ -219,9 +219,38 @@ func normalizeCodexToolChoice(reqBody map[string]any) bool { return false } choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"])) - if choiceType == "" || codexToolsContainType(reqBody["tools"], choiceType) { + if choiceType == "" { return false } + modified := false + if choiceType == "function" { + name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) + if name == "" { + if function, ok := choiceMap["function"].(map[string]any); ok { + name = strings.TrimSpace(firstNonEmptyString(function["name"])) + } + } + if name == "" { + reqBody["tool_choice"] = "auto" + return true + } + if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name { + choiceMap["name"] = name + modified = true + } + if _, ok := choiceMap["function"]; ok { + delete(choiceMap, "function") + modified = true + } + if !codexToolsContainFunctionName(reqBody["tools"], name) { + reqBody["tool_choice"] = "auto" + return true + } + return modified + } + if codexToolsContainType(reqBody["tools"], choiceType) { + return modified + } reqBody["tool_choice"] = "auto" return true } @@ -243,6 +272,33 @@ func codexToolsContainType(rawTools any, toolType string) bool { return false } +func codexToolsContainFunctionName(rawTools any, name string) bool { + tools, ok := rawTools.([]any) + if !ok || strings.TrimSpace(name) == "" { + return false + } + normalizedName := strings.TrimSpace(name) + for _, rawTool := range tools { + tool, ok := rawTool.(map[string]any) + if !ok { + continue + } + if 
strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" { + continue + } + toolName := strings.TrimSpace(firstNonEmptyString(tool["name"])) + if toolName == "" { + if function, ok := tool["function"].(map[string]any); ok { + toolName = strings.TrimSpace(firstNonEmptyString(function["name"])) + } + } + if toolName == normalizedName { + return true + } + } + return false +} + func normalizeCodexToolRoleMessages(input []any) ([]any, bool) { if len(input) == 0 { return input, false @@ -853,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any { } typ, _ := m["type"].(string) + // chatgpt.com codex backend (OAuth path) does not persist reasoning + // items because applyCodexOAuthTransform forces store=false. Any rs_* + // reference replayed in input is guaranteed to 404 upstream + // ("Item with id 'rs_...' not found"). Drop reasoning items entirely. + if typ == "reasoning" { + continue + } + // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id; // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。 fixCallIDPrefix := func(id string) string { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 75f5c55c..87bb7162 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,6 +1,8 @@ package service import ( + "fmt" + "strings" "testing" "github.com/stretchr/testify/require" @@ -249,6 +251,44 @@ func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) { require.Equal(t, "custom", choice["type"]) } +func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{"name": "shell"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, 
false) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "function", choice["type"]) + require.Equal(t, "shell", choice["name"]) + require.NotContains(t, choice, "function") +} + +func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "function", "name": "shell"}, + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{"name": "missing"}, + }, + } + + applyCodexOAuthTransform(reqBody, true, false) + + require.Equal(t, "auto", reqBody["tool_choice"]) +} + func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.4", @@ -1048,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) { "prompt_cache_retention must be stripped before forwarding to Codex upstream") } +func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "user": "user_123", + "metadata": map[string]any{"trace_id": "abc"}, + "prompt_cache_retention": "24h", + "safety_identifier": "sid", + "stream_options": map[string]any{"include_usage": true}, + "input": []any{ + map[string]any{"role": "user", "content": "hi"}, + }, + } + + result := applyCodexOAuthTransform(reqBody, true, false) + + require.True(t, result.Modified) + for _, field := range openAIChatGPTInternalUnsupportedFields { + require.NotContains(t, reqBody, field) + } +} + func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.1", @@ -1094,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) { }) } } + +func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) { + // Reasoning items in input[] reference rs_* IDs that were emitted by + // chatgpt.com under 
store=false (forced by applyCodexOAuthTransform). + // They are never persisted upstream, so forwarding them produces a + // guaranteed 404 ("Item with id 'rs_...' not found"). Drop them + // regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957. + + build := func() []any { + return []any{ + map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"}, + map[string]any{ + "type": "reasoning", + "id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344", + "summary": []any{}, + }, + map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"}, + map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"}, + } + } + + for _, preserve := range []bool{true, false} { + preserve := preserve + t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) { + filtered := filterCodexInput(build(), preserve) + + for _, raw := range filtered { + item, ok := raw.(map[string]any) + require.True(t, ok) + require.NotEqual(t, "reasoning", item["type"], + "reasoning items must be dropped from input on the OAuth path") + if id, ok := item["id"].(string); ok { + require.False(t, strings.HasPrefix(id, "rs_"), + "no item carrying an rs_* id should survive the filter") + } + } + + // Sanity check: the non-reasoning items should still be present. 
+ gotTypes := make(map[string]int) + for _, raw := range filtered { + item, ok := raw.(map[string]any) + require.True(t, ok) + typ, ok := item["type"].(string) + require.True(t, ok) + gotTypes[typ]++ + } + require.Equal(t, 1, gotTypes["message"]) + require.Equal(t, 1, gotTypes["function_call"]) + require.Equal(t, 1, gotTypes["function_call_output"]) + require.Equal(t, 0, gotTypes["reasoning"]) + }) + } +} diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go new file mode 100644 index 00000000..b52da614 --- /dev/null +++ b/backend/internal/service/openai_fast_policy_test.go @@ -0,0 +1,286 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIFastPolicyRepoStub struct { + values map[string]string +} + +func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} + +func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + return nil +} + +func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func 
newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService { + t.Helper() + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + if settings != nil { + raw, err := json.Marshal(settings) + require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw) + } + return &OpenAIGatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + +func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast + // 是用户级开关,与 model 正交。 + // gpt-5.5 + priority → filter + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-5.5-turbo → filter + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-4 + priority → filter(默认策略覆盖所有模型) + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-5.5 + flex → pass (tier doesn't match) + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex) + require.Equal(t, BetaPolicyActionPass, action) + + // empty tier → pass + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "") + require.Equal(t, BetaPolicyActionPass, action) +} + +func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + 
ErrorMessage: "fast mode is not allowed", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionBlock, action) + require.Equal(t, "fast mode is not allowed", msg) +} + +func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierAny, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeOAuth, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + + // OAuth account → rule matches + oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth} + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // API Key account → rule skipped → pass + apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionPass, action) +} + +func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // gpt-5.5 fast → service_tier stripped + body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // Client sending "fast" (alias for priority) also filtered + body = 
[]byte(`{"model":"gpt-5.5","service_tier":"fast"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除 + body = []byte(`{"model":"gpt-4","service_tier":"priority"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // No service_tier → no-op + body = []byte(`{"model":"gpt-5.5"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后 +// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被 +// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。 +func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + for _, tier := range []string{"auto", "default", "scale"} { + body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err, "tier %q should pass without error", tier) + require.Contains(t, string(updated), `"service_tier":"`+tier+`"`, + "tier %q should be preserved in body under default rule", tier) + } + + // evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配) + for _, tier := range []string{"auto", "default", "scale"} { + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier) + require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier) + } +} + +// 
TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置 +// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会 +// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。 +func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierAny, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} { + body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`, + "tier %q should be stripped under ServiceTier=all + filter rule", tier) + } +} + +// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离 +// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段; +// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在 +// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。 +func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // normalize 阶段会将未知值剥离 + require.Nil(t, normalizeOpenAIServiceTier("xxx")) + + // applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变 + // (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离) + body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) 
{ + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "fast mode is blocked for gpt-5.5", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.Error(t, err) + var blocked *OpenAIFastBlockedError + require.True(t, errors.As(err, &blocked)) + require.Contains(t, blocked.Message, "fast mode is blocked") + require.Equal(t, string(body), string(updated)) // body not mutated on block +} + +func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) { + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + svc := NewSettingService(repo, &config.Config{}) + + // Invalid action rejected + err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: "bogus", + Scope: BetaPolicyScopeAll, + }}, + }) + require.Error(t, err) + + // Invalid service_tier rejected + err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: "turbo", + Action: BetaPolicyActionPass, + Scope: BetaPolicyScopeAll, + }}, + }) + require.Error(t, err) + + // Valid settings persisted + err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }}, + }) + require.NoError(t, err) + + got, err := svc.GetOpenAIFastPolicySettings(context.Background()) + require.NoError(t, err) + 
require.Len(t, got.Rules, 1) + require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier) +} diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go new file mode 100644 index 00000000..3316a242 --- /dev/null +++ b/backend/internal/service/openai_fast_policy_ws_test.go @@ -0,0 +1,1018 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate --- + +func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier") + // Other fields preserved. 
+ require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String()) + require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String()) +} + +func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Verbatim "fast" → normalized to "priority" → matches default rule → filter. + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`) + + // Mixed-case + whitespace variant should also normalize and filter. + frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`) +} + +func TestWSResponseCreate_FlexPassThrough(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Default policy targets priority only; flex is left untouched. 
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy") +} + +func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "ws fast blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.NotNil(t, blocked) + require.Equal(t, "ws fast blocked", blocked.Message) + // On block, payload returned unchanged so caller can inspect / log it. 
+ require.Equal(t, string(frame), string(updated)) +} + +func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation") +} + +func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"*"}, + FallbackAction: BetaPolicyActionFilter, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // response.cancel happens to carry a service_tier-shaped field — must not be touched. + frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated)) +} + +// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the +// helper used to treat empty type as response.create, which risked stripping +// fields from malformed / unknown client events. After the A1 fix only a +// strict "response.create" match triggers policy. 
+func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"*"}, + FallbackAction: BetaPolicyActionFilter, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Frame with no "type" field: must pass through completely unchanged + // even with a service_tier-shaped field present. + frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through") + + // Explicit empty string also passes through. + frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated)) +} + +// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1 +// regression: the rendered Realtime error event must carry a non-empty +// event_id (so clients can correlate the rejection) and a stable error.code +// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error +// JSON body emitted by writeOpenAIFastPolicyBlockedResponse. 
+func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) { + bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"}) + require.NotNil(t, bytes) + + require.Equal(t, "error", gjson.GetBytes(bytes, "type").String()) + require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String()) + require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String()) + require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String()) + + eventID := gjson.GetBytes(bytes, "event_id").String() + require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs") + require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_ Realtime convention; got %q", eventID) + + // Sanity check: two consecutive events get distinct IDs. + other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"}) + otherID := gjson.GetBytes(other, "event_id").String() + require.NotEqual(t, eventID, otherID, "event_id must be random per-event") +} + +// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe ensures the helper returns +// nil for a nil error (defensive guard for callers that always invoke it). +func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) { + require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil)) +} + +// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback --- + +// fakePassthroughFrameConn replays a fixed sequence of client frames into the +// policy-enforcing wrapper, then returns io.EOF. Captures all Write attempts +// for write-side assertions (none expected in the D5 test, since the wrapper +// only filters reads). 
+type fakePassthroughFrameConn struct { + reads [][]byte + idx int + writes [][]byte + closeOnce bool +} + +func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if f.idx >= len(f.reads) { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + payload := f.reads[f.idx] + f.idx++ + return coderws.MessageText, payload, nil +} + +func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + cp := append([]byte(nil), payload...) + f.writes = append(f.writes, cp) + return nil +} + +func (f *fakePassthroughFrameConn) Close() error { + f.closeOnce = true + return nil +} + +// gpt55WhitelistFastPolicy 返回一份强制带 model whitelist 的策略,用于 +// 验证 capturedSessionModel fallback 的语义(默认策略 whitelist 为空时 +// fallback 路径无法被观察到)。 +func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings { + return &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"}, + FallbackAction: BetaPolicyActionPass, + }}, + } +} + +// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is +// the D5 regression: in passthrough mode a follow-up response.create frame +// without a "model" field must still hit the policy via the session-level +// model captured from the first frame. Without the fallback an empty model +// would miss a model whitelist and silently leak service_tier=priority +// through to the upstream. 
+func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) { + // 此处特意使用带 whitelist 的策略,以便观察 capturedSessionModel + // fallback 是否生效(默认策略 whitelist 为空,fallback 与否结果一致, + // 不能用来覆盖此回归)。 + svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Simulate the passthrough adapter capturing model from the first frame. + firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame) + require.Equal(t, "gpt-5.5", capturedSessionModel) + + // Follow-up frame deliberately omits "model" — Realtime allows this. + followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`) + + inner := &fakePassthroughFrameConn{ + reads: [][]byte{followupFrame}, + } + wrapper := &openAIWSPolicyEnforcingFrameConn{ + inner: inner, + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + }, + } + + // Read the follow-up frame through the wrapper. The policy MUST still + // trigger filter (gpt-5.5 + priority → filter), so the service_tier + // field is gone by the time the relay sees it. 
+ _, payload, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.NotContains(t, string(payload), `"service_tier"`, + "D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5") + require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String()) +} + +// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the +// inverse: when the wrapper has NO capturedSessionModel fallback (model is +// empty per-frame and no fallback is wired up), the policy fails to match +// the model whitelist and the frame leaks through unchanged. This documents +// exactly the leak the D5 fix prevents. +func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) { + // 同样使用带 whitelist 的策略以观察 leak。 + svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`) + inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}} + wrapper := &openAIWSPolicyEnforcingFrameConn{ + inner: inner, + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + // NO fallback — emulate the pre-fix behavior. + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + }, + } + + _, payload, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + // Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept. 
+ require.Contains(t, string(payload), `"service_tier"`, + "sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing") +} + +// --- Ingress end-to-end test (filter path) --- + +// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the +// real ProxyResponsesWebSocketFromClient ingress session pipeline against a +// captureConn upstream and asserts that a client frame with service_tier=fast +// is normalized + filtered out before being written upstream. This is the +// integration flavour of TestWSResponseCreate_FilterStripsServiceTier. +func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings()) + require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: 
&httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + settingService: NewSettingService(repo, cfg), + } + + account := &Account{ + ID: 901, + Name: "openai-ws-filter", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + _, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`))) + cancelWrite() + + readCtx, 
cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Len(t, captureConn.writes, 1, "上游应只收到一条 response.create") + upstream := captureConn.writes[0] + _, hasServiceTier := upstream["service_tier"] + require.False(t, hasServiceTier, "上游收到的 response.create 不应包含 service_tier 字段(已被 fast policy filter 删除)") + require.Equal(t, "response.create", upstream["type"]) + require.Equal(t, "gpt-5.5", upstream["model"]) +} + +// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the +// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It +// asserts that with a custom block rule, the client receives a Realtime-style +// error event AND the upstream FrameConn never receives the offending frame. 
+func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + // No events queued; the upstream should never get written to anyway. + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + blockSettings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "ws priority blocked for testing", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + raw, err := json.Marshal(blockSettings) + require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + settingService: NewSettingService(repo, cfg), + } + + account := &Account{ + ID: 902, + Name: "openai-ws-block", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: 
map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + _, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + // Mirror the production handler (openai_gateway_handler.go:1325-1328): + // when the proxy returns an OpenAIWSClientCloseError, surface its + // status code to the client via a graceful close handshake. Without + // this the deferred CloseNow() above would tear down the TCP + // connection without sending a close frame, and the C3 timing + // assertion (next read returns CloseStatus=1008) would see EOF + // instead. 
+ var closeErr *OpenAIWSClientCloseError + if errors.As(proxyErr, &closeErr) { + _ = conn.Close(closeErr.StatusCode(), closeErr.Reason()) + } + serverErrCh <- proxyErr + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`))) + cancelWrite() + + // C3 timing assertion: the FIRST frame the client reads must be the + // error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is + // synchronous (writeFrame Flushes the bufio writer at write.go:307-311 + // before returning) and the close handshake re-acquires the same + // writeFrameMu, so this ordering is enforced by the library itself; this + // assertion guards against future refactors that might break it. + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr, "first read must succeed and return the error event before any close frame") + require.Equal(t, "error", gjson.GetBytes(event, "type").String()) + require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String()) + // B1 regression: event_id + error.code must be populated. 
+ require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String()) + require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate") + require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing") + + // Next read must surface the close frame (as a CloseError). This + // asserts the [error event, close] ordering — i.e. the close did NOT + // race ahead of the data frame. + readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second) + _, _, secondReadErr := clientConn.Read(readCtx2) + cancelRead2() + require.Error(t, secondReadErr, "after the error event the connection must surface a close") + require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr), + "close status must be PolicyViolation; got %v", secondReadErr) + + select { + case serverErr := <-serverErrCh: + // Server returns an OpenAIWSClientCloseError — handler closes the WS; + // here we just assert it surfaced as the typed close error. + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.True(t, errors.As(serverErr, &closeErr), "block 应返回 OpenAIWSClientCloseError,得到 %T: %v", serverErr, serverErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress 关闭超时") + } + + // Critical: the offending frame must NEVER reach the upstream. + // captureDialer.DialCount may legitimately be 0 or 1 depending on whether + // the lease was acquired before policy fired; either way, no writes. 
+ require.Empty(t, captureConn.writes, "block 命中后上游不应收到 response.create") +} + +// --- HTTP-side gap-filling tests (already covered by existing tests but +// requested to be split out explicitly) --- + +// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that +// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule +// action is "block", and that the body is left untouched. The caller (chat +// completions / messages handlers) inspects this typed error and skips the +// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and +// openai_gateway_messages.go:149. +func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "priority blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.Error(t, err) + var blocked *OpenAIFastBlockedError + require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request") + require.Equal(t, "priority blocked", blocked.Message) + require.Equal(t, string(body), string(updated), "block must not mutate body") +} + +// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies +// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode +// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60) +// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has +// no service_tier. 
We exercise the same internal pipeline (Anthropic→Responses +// + BetaFastMode + policy) without spinning up a real upstream HTTP server. +func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50). + anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`) + var anthropicReq apicompat.AnthropicRequest + require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq)) + responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) + require.NoError(t, err) + + // Step 2: BetaFastMode header → service_tier="priority" (mirrors line 58-61). + headers := http.Header{} + headers.Set("anthropic-beta", claude.BetaFastMode) + require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode)) + responsesReq.ServiceTier = "priority" + responsesReq.Model = "gpt-5.5" + + // Step 3: marshal & apply fast policy (mirrors line 78 + 149). + responsesBody, err := json.Marshal(responsesReq) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置:beta 翻译应当注入 priority") + + upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody) + require.NoError(t, policyErr) + + // Step 4: assert that policy filtered the field before the upstream HTTP request. 
+ require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier") +} + +// --- Fix1: passthrough capturedSessionModel must follow session.update --- + +// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the +// fix1 bypass: client opens with a whitelist-miss model (gpt-4o → pass under +// gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then sends +// response.create without "model". Without the session.update sniffing the +// follow-up frame would fall back to the stale gpt-4o capture and pass — the +// fix updates capturedSessionModel from session.* events so the fallback now +// resolves to gpt-5.5 and the policy filters service_tier. +func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Frame 1: response.create with whitelist-miss model — under default + // rule fallback=pass, service_tier stays. + first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`) + // Frame 2: session.update rotates the session model to gpt-5.5. + rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`) + // Frame 3: response.create WITHOUT model — must inherit gpt-5.5. + followup := []byte(`{"type":"response.create","service_tier":"priority"}`) + + inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}} + + // Replicate the production wiring in openai_ws_v2_passthrough_adapter.go + // so capturedSessionModel state is shared across frames. 
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first) + require.Equal(t, "gpt-4o", capturedSessionModel) + wrapper := &openAIWSPolicyEnforcingFrameConn{ + inner: inner, + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + }, + } + + // Frame 1: gpt-4o miss whitelist → pass (service_tier preserved). + _, payload1, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier") + + // Frame 2: session.update — not response.create, untouched, but its + // side effect updates capturedSessionModel to gpt-5.5. + _, payload2, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim") + require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel") + + // Frame 3: empty model + new captured gpt-5.5 → matches whitelist → filter. 
+ _, payload3, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.NotContains(t, string(payload3), `"service_tier"`, + "fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter") +} + +// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative +// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only +// client→upstream session.update frames rotate the captured model; +// server→client events (session.created) and unrelated frames must not. +func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) { + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // session.created is a server→client event in the OpenAI Realtime + // protocol — clients never send it, so this filter (which only runs on + // the client→upstream direction) must ignore it even if it appears. + created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created)) + + // Non-session.* frames must NOT trigger rotation. + notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession)) + + // Missing session.model returns empty — caller keeps the old captured value. + noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel)) +} + +// --- Fix2: native /responses normalize "fast" → "priority" on pass --- + +// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2 +// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody +// returned the body unchanged so a raw "fast" alias would leak to the +// upstream OpenAI API (which does not accept "fast"). The fix normalizes +// "fast" → "priority" on pass too. 
+func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) { + // Use a policy that deliberately misses gpt-4 so the action is pass. + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // gpt-4 + "fast" → fallback pass. Body must be rewritten to "priority". + body := []byte(`{"model":"gpt-4","service_tier":"fast"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), + "fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug") + + // Already-canonical "priority" on pass: zero mutation (byte-equal). + body = []byte(`{"model":"gpt-4","service_tier":"priority"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) + + // Mixed-case alias → normalized. + body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String()) + + // Unrecognized tier → still no-op (not normalized, since normTier == ""). 
+ body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +// --- Fix3: passthrough billing must reflect post-filter service_tier --- + +// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The +// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts +// requestServiceTier from firstClientMessage AFTER applyOpenAIFastPolicy +// has rewritten it, so a filter hit causes billing to report nil (default +// tier) instead of the user-requested "priority". This test pins the +// contract those two helpers must uphold for the adapter's billing path. +func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + + // Pre-filter sanity: extracting from the raw frame would (incorrectly, + // pre-fix) report "priority" — this is the very thing the adapter + // must NOT do anymore. + pre := extractOpenAIServiceTierFromBody(raw) + require.NotNil(t, pre) + require.Equal(t, "priority", *pre, + "sanity: raw first frame carries priority that pre-fix billing would have reported") + + // Apply policy filter (default rule: gpt-5.5 + priority → filter). + filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(filtered), `"service_tier"`) + + // Post-filter: extracting from the rewritten frame returns nil. This + // is the value the adapter now passes to OpenAIForwardResult.ServiceTier, + // so billing records "default" instead of "priority". 
+ post := extractOpenAIServiceTierFromBody(filtered) + require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority") + + // And the byte-level invariant the adapter relies on: filtering an + // already-filtered frame is a no-op (idempotent), so re-running the + // policy doesn't accidentally re-introduce the field. + again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered) + require.NoError(t, err) + require.Nil(t, blocked2) + require.Equal(t, string(filtered), string(again), + "policy is idempotent: filtering an already-filtered frame leaves bytes unchanged") +} + +// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap +// flagged in the review: when a client sends service_tier as a non-string +// (number, null, object, etc.) the policy must NOT panic and must NOT +// pretend the field was filtered. Behavior: skip policy entirely (treat as +// "no usable tier"), forward body unchanged. This mirrors the HTTP entry's +// type-assertion `reqBody["service_tier"].(string); ok` guard. +func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Number — gjson .String() coerces to "1" which is not a recognized + // tier alias; normalize returns "" → policy no-ops. 
+ cases := [][]byte{ + []byte(`{"model":"gpt-5.5","service_tier":1}`), + []byte(`{"model":"gpt-5.5","service_tier":null}`), + []byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`), + []byte(`{"model":"gpt-5.5","service_tier":["priority"]}`), + []byte(`{"model":"gpt-5.5","service_tier":true}`), + } + for _, body := range cases { + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err, "non-string service_tier must not error: %s", string(body)) + require.Equal(t, string(body), string(updated), + "non-string service_tier must pass through unchanged: %s", string(body)) + } + + // Same guard for the WS response.create entry. + for _, body := range cases { + frame := body + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame)) + require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame)) + require.Equal(t, string(frame), string(updated), + "non-string service_tier ws frame must pass through unchanged: %s", string(frame)) + } +} + +// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the +// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS +// allows the client to ship a different service_tier on each response.create +// frame (per-response field, see codex-rs/core/src/client.rs +// build_responses_request which re-fills the field on every request). Before +// the fix the adapter only captured service_tier from firstClientMessage so +// turn 2/3 billing was wrong. After the fix the filter closure refreshes an +// atomic.Pointer[string] on every successful response.create frame. 
+// +// This test pins the four legs of the semantic contract: +// - turn 1: service_tier=priority hits the default whitelist filter, so +// after filter the upstream sees no tier → billing is nil. +// - turn 2: service_tier=flex passes (default rule targets priority only), +// billing should now reflect "flex". +// - turn 3: response.create without any service_tier — the upstream will +// treat it as default; we choose to mirror that and overwrite billing +// to nil rather than carry over "flex" from turn 2. +// - non-response.create frame (response.cancel here) carrying a stray +// service_tier-shaped field must NOT clobber the billing pointer. +func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go + // proxyResponsesWebSocketV2Passthrough) so this test fails if the + // production code drops the per-frame Store. 
+ var requestServiceTierPtr atomic.Pointer[string] + capturedSessionModel := "" + filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + } + + // First-frame initialization mirrors the adapter: extract from the + // post-filter payload so a filter-on-first-frame zeroes billing too. + firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame) + require.NoError(t, firstErr) + require.Nil(t, firstBlocked) + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut)) + capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame) + require.Nil(t, requestServiceTierPtr.Load(), + "turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier") + + // Turn 2: client switches to flex, should pass and update billing. 
+ turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`) + out2, blocked2, err2 := filter(coderws.MessageText, turn2) + require.NoError(t, err2) + require.Nil(t, blocked2) + require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched") + tier2 := requestServiceTierPtr.Load() + require.NotNil(t, tier2, "turn 2: billing must update to reflect flex") + require.Equal(t, "flex", *tier2) + + // A non-response.create frame with a stray service_tier-shaped field + // must NOT overwrite the billing pointer (those frames don't carry + // per-response service_tier in the Realtime spec). + cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`) + _, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame) + require.NoError(t, errCancel) + require.Nil(t, blockedCancel) + tierAfterCancel := requestServiceTierPtr.Load() + require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil") + require.Equal(t, "flex", *tierAfterCancel, + "non-response.create frames must not update billing tier even if they carry a service_tier-shaped field") + + // Turn 3: response.create without any service_tier. We deliberately + // overwrite billing back to nil so it tracks what the upstream actually + // sees on this turn (default tier). 
+ turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`) + out3, blocked3, err3 := filter(coderws.MessageText, turn3) + require.NoError(t, err3) + require.Nil(t, blocked3) + require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate") + require.Nil(t, requestServiceTierPtr.Load(), + "turn 3: response.create without service_tier overwrites billing to nil to match upstream default") +} + +// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the +// "block keeps previous" semantic: when policy returns block on a +// response.create frame, that frame is never sent upstream, so billing tier +// must keep the previous turn's value rather than getting silently zeroed. +func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) { + blockSettings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, blockSettings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + var requestServiceTierPtr atomic.Pointer[string] + flexValue := "flex" + requestServiceTierPtr.Store(&flexValue) // simulate prior turn billed as flex + + filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + } + + frame := 
[]byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + _, blocked, err := filter(coderws.MessageText, frame) + require.NoError(t, err) + require.NotNil(t, blocked, "policy must block this frame") + + tier := requestServiceTierPtr.Load() + require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil") + require.Equal(t, "flex", *tier, + "blocked frame is never sent upstream; billing must retain the previous turn's tier") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 663066a3..5822ae4c 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } } + // 4b. Apply OpenAI fast policy (may filter service_tier or block the request). + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) + } + return nil, policyErr + } + responsesBody = updatedBody + // 5. 
Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index a00fb71c..6846e03a 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) { normalizeResponsesRequestServiceTier(req) require.Equal(t, "flex", req.ServiceTier) + // OpenAI 官方合法 tier 应被透传保留。 + req.ServiceTier = "auto" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "auto", req.ServiceTier) + req.ServiceTier = "default" normalizeResponsesRequestServiceTier(req) + require.Equal(t, "default", req.ServiceTier) + + req.ServiceTier = "scale" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "scale", req.ServiceTier) + + // 真未知值仍被剥离。 + req.ServiceTier = "turbo" + normalizeResponsesRequestServiceTier(req) require.Empty(t, req.ServiceTier) } @@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Equal(t, "flex", tier) require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String()) + // OpenAI 官方 tier 直接保留在 body 中(透传上游)。 + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`)) + require.NoError(t, err) + require.Equal(t, "auto", tier) + require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String()) + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`)) require.NoError(t, err) + require.Equal(t, "default", tier) + require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`)) + require.NoError(t, err) + require.Equal(t, "scale", tier) + require.Equal(t, "scale", gjson.GetBytes(body, 
"service_tier").String()) + + // 真未知值才会被删除。 + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`)) + require.NoError(t, err) require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 2a0a72eb..4e0ebb2e 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } } + // 4c. Apply OpenAI fast policy (may filter service_tier or block the request). + // Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed + // on the body-level service_tier field (priority/flex). + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message) + } + return nil, policyErr + } + responsesBody = updatedBody + // 5. 
Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 9665c4c8..47ff4e3b 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, nil, nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, @@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) { require.Equal(t, "priority", *got) }) - t.Run("default ignored", func(t *testing.T) { - require.Nil(t, normalizeOpenAIServiceTier("default")) + t.Run("openai official tiers preserved", func(t *testing.T) { + // OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄 + // 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex, + // 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。 + for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} { + got := normalizeOpenAIServiceTier(tier) + require.NotNil(t, got, "tier %q should not be normalized to nil", tier) + require.Equal(t, tier, *got) + } }) t.Run("invalid ignored", func(t *testing.T) { require.Nil(t, normalizeOpenAIServiceTier("turbo")) + require.Nil(t, normalizeOpenAIServiceTier("xxx")) }) } func TestExtractOpenAIServiceTier(t *testing.T) { require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"})) require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"})) + require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"})) + require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"})) + require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"})) require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1})) 
require.Nil(t, extractOpenAIServiceTier(nil)) } @@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) { func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`))) require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`))) - require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`))) + require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`))) require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 379ebe0b..ed69730c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -334,6 +334,7 @@ type OpenAIGatewayService struct { resolver *ModelPricingResolver channelService *ChannelService balanceNotifyService *BalanceNotifyService + settingService *SettingService openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -372,6 +373,7 @@ func NewOpenAIGatewayService( resolver *ModelPricingResolver, channelService *ChannelService, balanceNotifyService *BalanceNotifyService, + settingService *SettingService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -402,6 +404,7 @@ func NewOpenAIGatewayService( resolver: resolver, channelService: channelService, balanceNotifyService: balanceNotifyService, + settingService: settingService, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: 
newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -1125,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str return sessionID } +func explicitOpenAISessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + +// GenerateExplicitSessionHash generates a sticky-session hash only from explicit +// client session signals. It intentionally skips content-derived fallback and is +// used by stateless endpoints such as /v1/images. +func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string { + sessionID := explicitOpenAISessionID(c, body) + if sessionID == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + // GenerateSessionHash generates a sticky-session hash for OpenAI requests. 
// // Priority: @@ -1137,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) return "" } - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } + sessionID := explicitOpenAISessionID(c, body) if sessionID == "" && len(body) > 0 { sessionID = deriveOpenAIContentSessionSeed(body) } @@ -2287,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } + // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤): + // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略 + // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽 + // fast 时在此生效。 + // + // 注意: + // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel + + // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与 + // chat-completions / messages 入口保持一致,避免不同入口因为模型 + // 维度不同而出现 whitelist 命中差异。 + // 2. 
action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body, + // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat- + // completions 入口由 normalizeResponsesBodyServiceTier 完成同一 + // 行为,这里手工实现等效逻辑。 + if rawTier, ok := reqBody["service_tier"].(string); ok { + if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" { + action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier) + switch action { + case BetaPolicyActionBlock: + msg := errMsg + if msg == "" { + msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel) + } + blocked := &OpenAIFastBlockedError{Message: msg} + writeOpenAIFastPolicyBlockedResponse(c, blocked) + return nil, blocked + case BetaPolicyActionFilter: + delete(reqBody, "service_tier") + bodyModified = true + disablePatch() + default: + // pass:若客户端传的是别名 "fast",归一化为 "priority" + // 后写回 body,确保上游收到的是其能识别的规范值。 + if normTier != rawTier { + reqBody["service_tier"] = normTier + bodyModified = true + markPatchSet("service_tier", normTier) + } + } + } + } + // Re-serialize body only if modified if bodyModified { serializedByPatch := false @@ -2735,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( body = sanitizedBody } + // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier). 
+ // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 + + // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。 + // 这样可以与 chat-completions / messages / native /responses 入口的 + // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有 + // model 字段时退回 reqModel。 + policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) + if policyModel == "" { + policyModel = reqModel + } + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeOpenAIFastPolicyBlockedResponse(c, blocked) + } + return nil, policyErr + } + body = updatedBody + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -4841,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) { } normalized := []byte(`{}`) - for _, field := range []string{"model", "input", "instructions", "previous_response_id"} { + // Keep the current Codex /compact schema while still dropping request-scoped + // fields such as prompt_cache_key, store, and stream. 
+ for _, field := range []string{ + "model", + "input", + "instructions", + "tools", + "parallel_tool_calls", + "reasoning", + "text", + "previous_response_id", + } { value := gjson.GetBytes(body, field) if !value.Exists() { continue @@ -5454,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p } // normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: -// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false +// 1) 删除 ChatGPT internal API 不支持的顶层 Responses 参数 +// 2) store=false 3) 非 compact 保持 stream=true;compact 强制 stream=false func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) { if len(body) == 0 { return body, false, nil @@ -5463,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo normalized := body changed := false + for _, field := range openAIChatGPTInternalUnsupportedFields { + if value := gjson.GetBytes(normalized, field); !value.Exists() { + continue + } + next, err := sjson.DeleteBytes(normalized, field) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err) + } + normalized = next + changed = true + } + if compact { if store := gjson.GetBytes(normalized, "store"); store.Exists() { next, err := sjson.DeleteBytes(normalized, "store") @@ -5567,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string { if value == "fast" { value = "priority" } + // 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。 + // 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs), + // 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。 + // 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。 switch value { - case "priority", "flex": + case "priority", "flex", "auto", "default", "scale": return &value default: return nil } } +// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast +// policy (action=block). 
Mirrors BetaBlockedError on the Claude side. +type OpenAIFastBlockedError struct { + Message string +} + +func (e *OpenAIFastBlockedError) Error() string { return e.Message } + +// evaluateOpenAIFastPolicy returns the action and error message that should be +// applied for a request with the given account/model/service_tier. When the +// policy service is unavailable or no rule matches, it returns +// (BetaPolicyActionPass, "") so callers can short-circuit safely. +// +// Matching rules: +// - Scope filters by account type (all / oauth / apikey / bedrock) +// - ServiceTier must be empty (= any), "all", or equal the normalized tier +// - ModelWhitelist narrows the rule to specific models; FallbackAction +// handles the non-matching case (default: pass) +// +// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit): +// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同 +// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。 +// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段, +// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier +// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist +// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而 +// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。 +func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) { + if s == nil || s.settingService == nil { + return BetaPolicyActionPass, "" + } + tier := strings.ToLower(strings.TrimSpace(serviceTier)) + if tier == "" { + return BetaPolicyActionPass, "" + } + settings := openAIFastPolicySettingsFromContext(ctx) + if settings == nil { + fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx) + if err != nil || fetched == nil { + return BetaPolicyActionPass, "" + } + settings = fetched + } + return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier) +} + +// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so +// long-lived sessions (e.g. 
WS) can prefetch settings once and avoid hitting +// the settingService on every frame. See WSSession entry and +// openAIFastPolicySettingsFromContext for the caching glue. +func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) { + if settings == nil { + return BetaPolicyActionPass, "" + } + isOAuth := account != nil && account.IsOAuth() + isBedrock := account != nil && account.IsBedrock() + for _, rule := range settings.Rules { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier)) + if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier { + continue + } + eff := BetaPolicyRule{ + Action: rule.Action, + ErrorMessage: rule.ErrorMessage, + ModelWhitelist: rule.ModelWhitelist, + FallbackAction: rule.FallbackAction, + FallbackErrorMessage: rule.FallbackErrorMessage, + } + return resolveRuleAction(eff, model) + } + return BetaPolicyActionPass, "" +} + +// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存 +// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。 +// +// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是 +// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude +// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员 +// 可以通过踢断 session 强制刷新。 +type openAIFastPolicyCtxKeyType struct{} + +var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{} + +// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx +// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。 +func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context { + if ctx == nil || settings == nil { + return ctx + } + return context.WithValue(ctx, openAIFastPolicyCtxKey, settings) +} + +func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings { + if ctx == nil { + return nil + } + if v, ok := 
ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok { + return v + } + return nil +} + +// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request +// body. When action=filter it removes the service_tier field; when +// action=block it returns (body, *OpenAIFastBlockedError). On pass it +// normalizes the service_tier value (e.g. client alias "fast" → "priority"), +// rewriting the body so the upstream receives a slug it recognizes. +// +// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本 +// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化 +// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等 +// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样 +// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。 +func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) { + if len(body) == 0 { + return body, nil + } + rawTier := gjson.GetBytes(body, "service_tier").String() + if rawTier == "" { + return body, nil + } + normTier := normalizedOpenAIServiceTierValue(rawTier) + if normTier == "" { + return body, nil + } + action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier) + switch action { + case BetaPolicyActionBlock: + msg := errMsg + if msg == "" { + msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model) + } + return body, &OpenAIFastBlockedError{Message: msg} + case BetaPolicyActionFilter: + trimmed, err := sjson.DeleteBytes(body, "service_tier") + if err != nil { + return body, fmt.Errorf("strip service_tier from body: %w", err) + } + return trimmed, nil + default: + // pass:把别名(如 "fast")写回为规范值("priority")。 + if normTier == rawTier { + return body, nil + } + updated, err := sjson.SetBytes(body, "service_tier", normTier) + if err != nil { + return body, fmt.Errorf("normalize service_tier on pass: %w", err) + } + return updated, nil + } +} + +// writeOpenAIFastPolicyBlockedResponse 
writes a 403 JSON response for a +// request blocked by the OpenAI fast policy. +func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) { + if c == nil || err == nil { + return + } + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": err.Message, + }, + }) +} + +// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy +// against a single client→upstream WebSocket frame whose top-level +// "type"=="response.create". It mirrors the HTTP-side +// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses +// WS payload: +// +// - pass: returns frame unchanged (newBytes == frame, blocked == nil) +// - filter: returns a copy with top-level service_tier removed +// - block: returns (frame, *OpenAIFastBlockedError) +// +// Only frames whose "type" field strictly equals "response.create" are +// inspected/mutated. Any other frame type — including the empty string — +// passes through untouched. The OpenAI Realtime client-event spec requires +// "type" to be set, so an empty type is treated as a malformed frame we do +// not police; the upstream is the source of truth for rejecting it. +// +// service_tier lives at the top level of response.create — same as the +// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 + +// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at +// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need +// to inspect / strip the top-level field; there is no nested form in the +// schema today. +// +// The caller is responsible for choosing the upstream model passed in — +// this helper does not re-derive it. 
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate( + ctx context.Context, + account *Account, + model string, + frame []byte, +) ([]byte, *OpenAIFastBlockedError, error) { + if len(frame) == 0 { + return frame, nil, nil + } + if !gjson.ValidBytes(frame) { + return frame, nil, nil + } + frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String()) + // Strict match: only response.create is policy-checked. Empty / other + // types pass through untouched so we never accidentally strip fields + // from response.cancel, conversation.item.create, or any future + // client-event the spec adds. The Realtime spec requires "type" on + // every client event, so an empty type is malformed input — let the + // upstream reject it rather than guessing at our layer. + if frameType != "response.create" { + return frame, nil, nil + } + rawTier := gjson.GetBytes(frame, "service_tier").String() + if rawTier == "" { + return frame, nil, nil + } + normTier := normalizedOpenAIServiceTierValue(rawTier) + if normTier == "" { + return frame, nil, nil + } + action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier) + switch action { + case BetaPolicyActionBlock: + msg := errMsg + if msg == "" { + msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model) + } + return frame, &OpenAIFastBlockedError{Message: msg}, nil + case BetaPolicyActionFilter: + trimmed, err := sjson.DeleteBytes(frame, "service_tier") + if err != nil { + return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err) + } + return trimmed, nil, nil + default: + return frame, nil, nil + } +} + +// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a +// server-emitted error event. Matches the loose "evt_" convention used +// by upstream Realtime servers; the exact value is not load-bearing and is +// only required for client-side log correlation. 
We reuse the existing +// google/uuid dependency rather than pulling a new one. +func newOpenAIFastPolicyWSEventID() string { + id, err := uuid.NewRandom() + if err != nil { + // Extremely unlikely; fall back to a fixed prefix so the field is + // still non-empty and the schema stays self-consistent. + return "evt_openai_fast_policy" + } + // Strip dashes so it visually matches "evt_" rather than UUID v4 + // canonical form, mirroring what real Realtime traces look like. + return "evt_" + strings.ReplaceAll(id.String(), "-", "") +} + +// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses +// style "error" event payload for a request blocked by the OpenAI fast +// policy. The shape mirrors Realtime error events as observed in upstream +// traces and per the spec's server "error" event: +// +// { +// "event_id": "evt_", +// "type": "error", +// "error": { +// "type": "invalid_request_error", +// "code": "policy_violation", +// "message": "..." +// } +// } +// +// event_id lets clients correlate the rejection in their logs; "code" gives +// programmatic clients a stable identifier (HTTP-side equivalent is the +// 403 permission_error JSON body). +func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte { + if err == nil { + return nil + } + eventID := newOpenAIFastPolicyWSEventID() + payload, mErr := json.Marshal(map[string]any{ + "event_id": eventID, + "type": "error", + "error": map[string]any{ + "type": "invalid_request_error", + "code": "policy_violation", + "message": err.Message, + }, + }) + if mErr != nil { + // Fallback to a minimal hand-rolled payload; Marshal of the literal + // shape above should never fail in practice. 
+ return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`) + } + return payload +} + func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) { if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) { return body, false, nil diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 9c08d5f2..5d1c6fc6 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) } +func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := &OpenAIGatewayService{} + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`) + + t.Run("stateless image body stays unstuck", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + require.Empty(t, svc.GenerateExplicitSessionHash(c, body)) + require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context())) + }) + + t.Run("prompt_cache_key is explicit", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`)) + require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) + }) + + t.Run("header 
overrides body", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + c.Request.Header.Set("session_id", "header-session") + + got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`)) + require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got) + }) +} + func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -1756,6 +1791,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) { } } +func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) { + body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`) + + normalized, changed, err := normalizeOpenAICompactRequestBody(body) + + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String()) + require.True(t, gjson.GetBytes(normalized, "tools").Exists()) + require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool()) + require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String()) + require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String()) + require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String()) + require.False(t, gjson.GetBytes(normalized, "store").Exists()) + require.False(t, gjson.GetBytes(normalized, "stream").Exists()) + require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists()) +} + func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) 
{ gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 200547d4..47113d4d 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -258,6 +259,25 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) } +func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) { + require.Equal(t, + "https://image-upstream.example/v1/images/generations", + buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint), + ) + require.Equal(t, + "https://image-upstream.example/v1/images/edits", + buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint), + ) + require.Equal(t, + "https://image-upstream.example/v1/images/generations", + buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint), + ) + require.Equal(t, + "https://image-upstream.example/v1/images/generations", + buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint), + ) +} + type openAIImageTestSSEEvent struct { Name string Data string @@ -371,6 +391,124 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, 
"/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_apikey"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 6, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "gpt-image-2", result.Model) + require.Equal(t, "gpt-image-2", result.UpstreamModel) + + upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder) + require.True(t, ok) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + +func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := 
multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace background")) + imagePart, err := writer.CreateFormFile("image", "source.png") + require.NoError(t, err) + _, err = imagePart.Write([]byte("png-image-content")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"req_img_edit_apikey"}, + }, + Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + + account := &Account{ + ID: 7, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + "base_url": "https://image-upstream.example/v1/", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + + upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder) + require.True(t, ok) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data") + require.Contains(t, 
string(upstream.lastBody), `name="model"`) + require.Contains(t, string(upstream.lastBody), "gpt-image-2") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) +} + func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) diff --git a/backend/internal/service/openai_passthrough_normalization_test.go b/backend/internal/service/openai_passthrough_normalization_test.go new file mode 100644 index 00000000..492ff610 --- /dev/null +++ b/backend/internal/service/openai_passthrough_normalization_test.go @@ -0,0 +1,33 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`) + + normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false) + require.NoError(t, err) + require.True(t, changed) + for _, field := range openAIChatGPTInternalUnsupportedFields { + require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field) + } + require.True(t, gjson.GetBytes(normalized, "stream").Bool()) + require.False(t, gjson.GetBytes(normalized, "store").Bool()) +} + +func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`) + + normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true) + require.NoError(t, err) + require.True(t, changed) + require.False(t, 
gjson.GetBytes(normalized, "user").Exists()) + require.False(t, gjson.GetBytes(normalized, "metadata").Exists()) + require.False(t, gjson.GetBytes(normalized, "stream").Exists()) + require.False(t, gjson.GetBytes(normalized, "store").Exists()) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 8c0222e2..d1386b1b 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1366,16 +1366,27 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string func shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled bool, turn int, - hasFunctionCallOutput bool, + signals ToolContinuationSignals, currentPreviousResponseID string, expectedPreviousResponseID string, ) bool { - if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput { return false } if strings.TrimSpace(currentPreviousResponseID) != "" { return false } + if signals.HasFunctionCallOutputMissingCallID { + return false + } + // If the client already sent the actual tool-call context, treat this as + // a full replay / self-contained continuation payload rather than + // downgrading it into an inferred delta continuation. item_reference alone + // is not enough on the store=false WS path: it still needs a valid prior + // response anchor so upstream can resolve the referenced function_call. 
+ if signals.HasToolCallContext { + return false + } return strings.TrimSpace(expectedPreviousResponseID) != "" } @@ -2366,6 +2377,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return errors.New("token is empty") } + // 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session + // 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧 + // 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。 + if s.settingService != nil { + if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil { + ctx = withOpenAIFastPolicyContext(ctx, settings) + } + } + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled ingressMode := OpenAIWSIngressModeCtxPool @@ -2524,6 +2544,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( normalized = next } + // Apply OpenAI Fast Policy on the response.create frame using the same + // evaluator/normalize/scope rules as the HTTP entrypoints. This is the + // single integration point for all WS ingress turns (first + follow-up + // frames flow through here). + // + // Model fallback: parseClientPayload above rejects any frame whose + // "model" field is missing (line ~2493-2500), so by the time we + // reach this point upstreamModel is always derived from a non-empty + // per-frame model. The capturedSessionModel fallback used in the + // passthrough adapter is therefore not needed in this path. + policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized) + if policyErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr) + } + if blocked != nil { + // Send a Realtime-style error event to the client first, then + // signal the handler to close the connection with PolicyViolation. 
+ // We intentionally do NOT forward this frame upstream. + // + // coder/websocket@v1.8.14 Conn.Write is synchronous and flushes + // the underlying bufio writer before returning (write.go:42 → + // 307-311), and the subsequent close handshake re-acquires the + // same writeFrameMu, so the error event is guaranteed to reach + // the kernel send buffer before any close frame is queued. + eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes != nil { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancel() + } + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + blocked.Message, + blocked, + ) + } + normalized = policyApplied + return openAIWSClientPayload{ payloadRaw: normalized, rawForHash: trimmed, @@ -3132,13 +3190,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( skipBeforeTurn = false currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) - hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + toolSignals := ToolContinuationSignals{ + HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(), + } + if toolSignals.HasFunctionCallOutput { + var currentReqBody map[string]any + if err := json.Unmarshal(currentPayload, &currentReqBody); err == nil { + toolSignals = AnalyzeToolContinuationSignals(currentReqBody) + } + } + hasFunctionCallOutput := toolSignals.HasFunctionCallOutput // store=false + function_call_output 场景必须有续链锚点。 // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 if shouldInferIngressFunctionCallOutputPreviousResponseID( storeDisabled, turn, - hasFunctionCallOutput, + toolSignals, currentPreviousResponseID, expectedPrev, ) { diff --git 
a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 6bf9a9ff..30fd4142 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: 
&httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-tool-context", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer 
cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + 
cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{captureConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 115, + Name: "openai-ingress-item-reference", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = 
req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`) + secondTurn := readMessage() + 
require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID") +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { gin.SetMode(gin.TestMode) prevPreflightPingIdle := openAIWSIngressPreflightPingIdle diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index ff35cb01..c735f50a 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { name string storeDisabled bool turn int - hasFunctionCallOutput bool + signals ToolContinuationSignals currentPreviousResponse string expectedPrevious string want bool }{ { - name: "infer_when_all_conditions_match", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: true, + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: true, }, { - name: "skip_when_store_enabled", - storeDisabled: false, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + signals: 
ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_on_first_turn", - storeDisabled: true, - turn: 1, - hasFunctionCallOutput: true, - expectedPrevious: "resp_1", - want: false, + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "resp_1", + want: false, }, { - name: "skip_without_function_call_output", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: false, - expectedPrevious: "resp_1", - want: false, + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{}, + expectedPrevious: "resp_1", + want: false, }, { name: "skip_when_request_already_has_previous_response_id", storeDisabled: true, turn: 2, - hasFunctionCallOutput: true, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, currentPreviousResponse: "resp_client", expectedPrevious: "resp_1", want: false, }, { - name: "skip_when_last_turn_response_id_missing", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: "", - want: false, + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: "", + want: false, }, { - name: "trim_whitespace_before_judgement", - storeDisabled: true, - turn: 2, - hasFunctionCallOutput: true, - expectedPrevious: " resp_2 ", - want: true, + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true}, + expectedPrevious: " resp_2 ", + want: true, + }, + { + name: "skip_when_tool_call_context_already_present", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true}, + expectedPrevious: "resp_2", + want: false, + }, + { + name: 
"infer_when_only_item_reference_covers_call_ids", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true}, + expectedPrevious: "resp_2", + want: true, + }, + { + name: "skip_when_function_call_output_missing_call_id", + storeDisabled: true, + turn: 2, + signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true}, + expectedPrevious: "resp_2", + want: false, }, } @@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { got := shouldInferIngressFunctionCallOutputPreviousResponseID( tt.storeDisabled, tt.turn, - tt.hasFunctionCallOutput, + tt.signals, tt.currentPreviousResponse, tt.expectedPrevious, ) diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 66e5db93..f3936de1 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index cda2e351..3dbb199a 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct { conn *coderws.Conn } +// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs +// every client→upstream frame through the OpenAI Fast Policy. It is the +// passthrough-relay equivalent of the parseClientPayload integration in the +// ingress session path. 
filter returns: +// - newPayload, nil, nil: forward the (possibly mutated) payload +// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error +// event via onBlock and surfaces a transport-level error so the relay +// stops reading from the client. +// - _, _, err: a transport error other than block. +type openAIWSPolicyEnforcingFrameConn struct { + inner openaiwsv2.FrameConn + filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) + onBlock func(blocked *OpenAIFastBlockedError) +} + +var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil) + +func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.inner == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + msgType, payload, err := c.inner.ReadFrame(ctx) + if err != nil { + return msgType, payload, err + } + if c.filter == nil { + return msgType, payload, nil + } + updated, blocked, filterErr := c.filter(msgType, payload) + if filterErr != nil { + return msgType, payload, filterErr + } + if blocked != nil { + if c.onBlock != nil { + c.onBlock(blocked) + } + return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked) + } + return msgType, updated, nil +} + +func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.inner == nil { + return errOpenAIWSConnClosed + } + return c.inner.WriteFrame(ctx, msgType, payload) +} + +func (c *openAIWSPolicyEnforcingFrameConn) Close() error { + if c == nil || c.inner == nil { + return nil + } + return c.inner.Close() +} + +// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective +// model name that should be passed to evaluateOpenAIFastPolicy for a single +// passthrough WS frame. 
Mirrors the HTTP-side normalization +// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path +// matches model whitelists identically. +func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string { + if account == nil || len(payload) == 0 { + return "" + } + original := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if original == "" { + return "" + } + return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) +} + +// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model +// derived from a session.update frame's session.model field. Returns "" when +// the frame is not a session.update event or carries no session.model. Used +// by the per-frame policy filter (client→upstream direction) to keep +// capturedSessionModel in sync with the session-level model the client may +// rotate mid-session. +// +// Realtime / Responses WS lets the client change the session model after +// the WS handshake via: +// +// {"type":"session.update","session":{"model":"gpt-5.5", ...}} +// +// If we only capture the model from the very first frame, a client can ship +// gpt-4o on the first response.create (whitelisted as pass), then +// session.update to gpt-5.5, then send response.create without "model" so +// the per-frame resolver returns "" and the stale capturedSessionModel falls +// back to gpt-4o — defeating the gpt-5.5 fast-policy filter. 
+func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string { + if account == nil || len(payload) == 0 { + return "" + } + frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + if frameType != "session.update" { + return "" + } + original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String()) + if original == "" { + return "" + } + return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) +} + const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) @@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( return errors.New("token is empty") } requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) - requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage) requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) logOpenAIWSV2Passthrough( "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", @@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( len(firstClientMessage), ) + // Apply OpenAI Fast Policy on the first response.create frame. Subsequent + // frames are filtered via a wrapping FrameConn below so every client→ + // upstream frame goes through the same policy evaluator/normalize/scope as + // HTTP entrypoints. + // + // We capture the session-level model from the first frame here so the + // per-frame filter (below) can fall back to it when a follow-up frame + // omits "model" — Realtime clients are allowed to send response.create + // without re-stating the model, in which case the upstream uses the model + // negotiated at session.update time. 
Without this fallback, an empty + // model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be + // silently passed through, defeating the policy on every frame after + // the first. + capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) + updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) + if policyErr != nil { + return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) + } + if blocked != nil { + // coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires + // writeFrameMu, writes the entire frame, and Flushes the underlying + // bufio writer before returning (write.go:42 → write.go:307-311). + // The subsequent close handshake re-acquires the same writeFrameMu + // to send the close frame, so the error event is guaranteed to + // reach the kernel send buffer before any close frame is queued. + // No explicit flush hop is required here. 
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes != nil { + writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancelWrite() + } + return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked) + } + firstClientMessage = updatedFirst + + // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter + // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当 + // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的 + // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody)) + // 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。 + // + // 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在 + // 同一连接的不同 response.create 帧上发送不同 service_tier(参考 + // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 + // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream + // goroutine)和 OnTurnComplete / final result(runUpstreamToClient + // goroutine)之间同步当前 turn 的 service_tier。 + // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型, + // 可直接 Store/Load 而无需额外封装。 + var requestServiceTierPtr atomic.Pointer[string] + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage)) + wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return fmt.Errorf("build ws url: %w", err) @@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( } completedTurns := atomic.Int32{} + policyClientConn := &openAIWSPolicyEnforcingFrameConn{ + inner: &openAIWSClientFrameConn{conn: clientConn}, + // 注意线程安全:filter 仅在 runClientToUpstream 这一条 + // goroutine 中被调用(passthrough_relay.go: ReadFrame loop), + // capturedSessionModel 的读写都发生在该 goroutine 内,因此无需 + // 加锁/原子化。 + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + // 在评估策略前先刷新 capturedSessionModel:客户端可能通过 + 
// session.update 修改 session-level model(Realtime / + // Responses WS 协议允许),如果不刷新就会出现 + // "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5 + // → 不带 model 的 response.create fallback 到 gpt-4o" 的 + // 绕过路径。这里只看 session.update 事件中的 session.model + // 字段,response.create 自己的 model 仍然由其本帧字段决定。 + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + // Per-frame model first; if the client omits "model" on a + // follow-up frame (legal in Realtime), fall back to the + // session-level model captured from the first frame so the + // model whitelist still resolves. An empty model would miss + // any whitelist and silently fall back to pass. + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload) + // 多轮 passthrough billing:仅在成功(non-block / non-err) + // 的 response.create 帧上更新 requestServiceTierPtr,使用 + // filter 处理后的 payload,与首帧 policy-after-extract 语义 + // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。 + // - 非 response.create 帧(response.cancel / + // conversation.item.create / session.update 等)不携带 + // per-response service_tier,不应覆盖前一轮值。 + // - blocked != nil:该帧不会发送上游,billing tier 应保持 + // 上一轮值。 + // - policyErr != nil:异常路径,保持上一轮值。 + // - 不带 service_tier 的 response.create 会让 + // extractOpenAIServiceTierFromBody 返回 nil;这里有意 + // 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传 + // service_tier 时按 default 处理,billing 应如实反映。 + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + }, + onBlock: func(blocked *OpenAIFastBlockedError) { + // See note above on Conn.Write being synchronous w.r.t. 
flush; + // no explicit flush is required to ensure the error event lands + // before the close frame. + eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes == nil { + return + } + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancel() + }, + } relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ Ctx: ctx, - ClientConn: &openAIWSClientFrameConn{conn: clientConn}, + ClientConn: policyClientConn, UpstreamConn: upstreamFrameConn, FirstClientMessage: firstClientMessage, Options: openaiwsv2.RelayOptions{ @@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, - ServiceTier: requestServiceTier, + ServiceTier: requestServiceTierPtr.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), @@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, - ServiceTier: requestServiceTier, + ServiceTier: requestServiceTierPtr.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 08a10a02..44ec1ad1 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -184,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string { ) } +// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。 +// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据 +// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true +// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天 +// +// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE": +// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成 +// - 无 WAL 写入、无后续 VACUUM 压力 +// - 这些 ops 表只有 cleanup 
任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略 +func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) { + if days < 0 { + return time.Time{}, false, false + } + if days == 0 { + return time.Time{}, true, true + } + return now.AddDate(0, 0, -days), false, true +} + func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) { out := opsCleanupDeletedCounts{} if s == nil || s.db == nil || s.cfg == nil { @@ -194,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet now := time.Now().UTC() - // Error-like tables: error logs / retry attempts / alert events. - if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false) + // runOne 把"truncate? cutoff? batched delete?"封装到一处, + // 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。 + runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) { + if truncate { + return truncateOpsTable(ctx, s.db, table) + } + return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate) + } + + // Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits. 
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false) if err != nil { return out, err } out.errorLogs = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false) if err != nil { return out, err } out.retryAttempts = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false) if err != nil { return out, err } out.alertEvents = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false) if err != nil { return out, err } out.systemLogs = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false) + n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false) if err != nil { return out, err } @@ -229,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet } // Minute-level metrics snapshots. - if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false) + if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false) if err != nil { return out, err } @@ -239,15 +265,14 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet } // Pre-aggregation tables (hourly/daily). 
- if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 { - cutoff := now.AddDate(0, 0, -days) - n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false) + if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok { + n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false) if err != nil { return out, err } out.hourlyPreagg = n - n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true) + n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true) if err != nil { return out, err } @@ -303,7 +328,7 @@ WHERE id IN (SELECT id FROM batch) res, err := db.ExecContext(ctx, q, cutoff, batchSize) if err != nil { // If ops tables aren't present yet (partial deployments), treat as no-op. - if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") { + if isMissingRelationError(err) { return total, nil } return total, err @@ -320,6 +345,46 @@ WHERE id IN (SELECT id FROM batch) return total, nil } +// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。 +// +// 与 deleteOldRowsByID 的差异: +// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义 +// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力 +// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略 +// +// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。 +func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) { + if db == nil { + return 0, nil + } + var count int64 + if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil { + if isMissingRelationError(err) { + return 0, nil + } + return 0, fmt.Errorf("count %s: %w", table, err) + } + if count == 0 { + return 0, nil + } + if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil { + if isMissingRelationError(err) { + return 0, nil 
+ } + return 0, fmt.Errorf("truncate %s: %w", table, err) + } + return count, nil +} + +// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。 +func isMissingRelationError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + return strings.Contains(s, "does not exist") && strings.Contains(s, "relation") +} + func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) { if s == nil { return nil, false diff --git a/backend/internal/service/ops_cleanup_service_test.go b/backend/internal/service/ops_cleanup_service_test.go new file mode 100644 index 00000000..86657d27 --- /dev/null +++ b/backend/internal/service/ops_cleanup_service_test.go @@ -0,0 +1,64 @@ +package service + +import ( + "testing" + "time" +) + +func TestOpsCleanupPlan(t *testing.T) { + now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + + cases := []struct { + name string + days int + wantOK bool + wantTruncate bool + wantCutoff time.Time + }{ + {name: "negative skips", days: -1, wantOK: false}, + {name: "zero truncates", days: 0, wantOK: true, wantTruncate: true}, + {name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cutoff, truncate, ok := opsCleanupPlan(now, tc.days) + if ok != tc.wantOK { + t.Fatalf("ok = %v, want %v", ok, tc.wantOK) + } + if !ok { + return + } + if truncate != tc.wantTruncate { + t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate) + } + if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) { + t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff) + } + }) + } +} + +func TestIsMissingRelationError(t *testing.T) { + cases := []struct { + name string + err error + want bool + }{ + {name: "nil is not missing", err: nil, want: false}, + {name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true}, + {name: "match 
case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true}, + {name: "non-matching error", err: fakeErr("connection refused"), want: false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := isMissingRelationError(tc.err); got != tc.want { + t.Fatalf("got %v, want %v", got, tc.want) + } + }) + } +} + +type fakeErr string + +func (e fakeErr) Error() string { return string(e) } diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index 5871166c..ecc3a94b 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) { if cfg.DataRetention.CleanupSchedule == "" { cfg.DataRetention.CleanupSchedule = "0 2 * * *" } - if cfg.DataRetention.ErrorLogRetentionDays <= 0 { + // 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留; + // 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。 + if cfg.DataRetention.ErrorLogRetentionDays < 0 { cfg.DataRetention.ErrorLogRetentionDays = 30 } - if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 { + if cfg.DataRetention.MinuteMetricsRetentionDays < 0 { cfg.DataRetention.MinuteMetricsRetentionDays = 30 } - if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 { + if cfg.DataRetention.HourlyMetricsRetentionDays < 0 { cfg.DataRetention.HourlyMetricsRetentionDays = 30 } // Normalize auto refresh interval (default 30 seconds) @@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error { if cfg == nil { return errors.New("invalid config") } - if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 { - return errors.New("error_log_retention_days must be between 1 and 365") + // 保留天数:0 表示每次清理全部,1-365 表示按天数保留。 + if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 { + return errors.New("error_log_retention_days must be between 0 and 365") } - if 
cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 { - return errors.New("minute_metrics_retention_days must be between 1 and 365") + if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 { + return errors.New("minute_metrics_retention_days must be between 0 and 365") } - if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 { - return errors.New("hourly_metrics_retention_days must be between 1 and 365") + if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 { + return errors.New("hourly_metrics_retention_days must be between 0 and 365") } if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 { return errors.New("auto_refresh_interval_seconds must be between 15 and 300") diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index c6167447..5df69aea 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -269,7 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e switch action { case redeemActionSkipCompleted: - s.applyAffiliateRebateForOrder(ctx, o) + if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { + return err + } // Code already created and redeemed — just mark completed return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") case redeemActionCreate: @@ -283,7 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } - s.applyAffiliateRebateForOrder(ctx, o) + if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { + return err + } return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") } @@ -361,12 +365,12 @@ func (s 
*PaymentService) hasAuditLog(ctx context.Context, orderID int64, action return c > 0 } -func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) { +func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error { if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 { - return + return nil } if s.affiliateService == nil { - return + return nil } tx, err := s.entClient.Tx(ctx) @@ -374,7 +378,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("begin affiliate rebate tx: %v", err), }) - return + return fmt.Errorf("begin affiliate rebate tx: %w", err) } defer func() { _ = tx.Rollback() }() @@ -384,10 +388,10 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("claim affiliate rebate audit: %w", err) } if !claimed { - return + return nil } rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount) @@ -395,7 +399,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("accrue affiliate rebate: %w", err) } if rebateAmount <= 0 { @@ -406,14 +410,15 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("update affiliate rebate skipped audit: %w", err) } if err := tx.Commit(); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), }) + 
return fmt.Errorf("commit affiliate rebate tx: %w", err) } - return + return nil } if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{ @@ -423,14 +428,16 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), }) - return + return fmt.Errorf("update affiliate rebate applied audit: %w", err) } if err := tx.Commit(); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err), }) + return fmt.Errorf("commit affiliate rebate tx: %w", err) } + return nil } func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) { @@ -444,11 +451,11 @@ func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, clien }) rows, err := client.QueryContext(ctx, ` INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) -SELECT $1, 'AFFILIATE_REBATE_APPLIED', $2, 'system', NOW() +SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW() WHERE NOT EXISTS ( SELECT 1 FROM payment_audit_logs - WHERE order_id = $1 + WHERE order_id = $1::text AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') ) ON CONFLICT (order_id, action) DO NOTHING diff --git a/backend/internal/service/scheduler_cache.go b/backend/internal/service/scheduler_cache.go index f36135e0..f9794c82 100644 --- a/backend/internal/service/scheduler_cache.go +++ b/backend/internal/service/scheduler_cache.go @@ -59,6 +59,8 @@ type SchedulerCache interface { UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error // TryLockBucket 尝试获取分桶重建锁。 TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) + // UnlockBucket 释放分桶重建锁。 + UnlockBucket(ctx 
context.Context, bucket SchedulerBucket) error // ListBuckets 返回已注册的分桶集合。 ListBuckets(ctx context.Context) ([]SchedulerBucket, error) // GetOutboxWatermark 读取 outbox 水位。 diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go index 5c0b289b..0b32c2ad 100644 --- a/backend/internal/service/scheduler_snapshot_hydration_test.go +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched return true, nil } +func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error { + return nil +} + func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) { return nil, nil } diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 62b6993d..a68cdf0c 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch if !ok { return nil } + defer func() { + _ = s.cache.UnlockBucket(ctx, bucket) + }() rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f871ee85..2bae686a 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -82,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second // cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL) type cachedGatewayForwardingSettings struct { - fingerprintUnification bool - metadataPassthrough bool - cchSigning bool - expiresAt int64 // unix nano + fingerprintUnification bool + metadataPassthrough bool + cchSigning bool + anthropicCacheTTL1hInjection bool + 
expiresAt int64 // unix nano } var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings @@ -1175,6 +1176,24 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) + if settings.AffiliateRebateFreezeHours < 0 { + settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault + } + if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax { + settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax + } + updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours) + if settings.AffiliateRebateDurationDays < 0 { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault + } + if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax + } + updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays) + if settings.AffiliateRebatePerInviteeCap < 0 { + settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault + } + updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64) updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { @@ -1227,6 +1246,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) 
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection) updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) @@ -1287,10 +1307,11 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { }) gatewayForwardingSF.Forget("gateway_forwarding") gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ @@ -1397,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { return false } -// GetGatewayForwardingSettings returns cached gateway forwarding settings. -// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. -// Returns (fingerprintUnification, metadataPassthrough, cchSigning). 
-func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) { +type gatewayForwardingSettingsResult struct { + fp, mp, cch, cacheTTL1h bool +} + +func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + } } } - type gwfResult struct { - fp, mp, cch bool - } val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + }, nil } } dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) @@ -1421,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing SettingKeyEnableFingerprintUnification, SettingKeyEnableMetadataPassthrough, SettingKeyEnableCCHSigning, + SettingKeyEnableAnthropicCacheTTL1hInjection, }) if err != nil { slog.Warn("failed to get gateway forwarding settings", "error", err) gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: true, - metadataPassthrough: false, - cchSigning: false, - expiresAt: 
time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), + fingerprintUnification: true, + metadataPassthrough: false, + cchSigning: false, + anthropicCacheTTL1hInjection: false, + expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), }) - return gwfResult{true, false, false}, nil + return gatewayForwardingSettingsResult{fp: true}, nil } fp := true if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { @@ -1438,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing } mp := values[SettingKeyEnableMetadataPassthrough] == "true" cch := values[SettingKeyEnableCCHSigning] == "true" + cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: fp, - metadataPassthrough: mp, - cchSigning: cch, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + fingerprintUnification: fp, + metadataPassthrough: mp, + cchSigning: cch, + anthropicCacheTTL1hInjection: cacheTTL1h, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) - return gwfResult{fp, mp, cch}, nil + return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil }) - if r, ok := val.(gwfResult); ok { - return r.fp, r.mp, r.cch + if r, ok := val.(gatewayForwardingSettingsResult); ok { + return r } - return true, false, false // fail-open defaults + return gatewayForwardingSettingsResult{fp: true} +} + +// GetGatewayForwardingSettings returns cached gateway forwarding settings. +// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. +// Returns (fingerprintUnification, metadataPassthrough, cchSigning). 
+func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) { + result := s.getGatewayForwardingSettingsCached(ctx) + return result.fp, result.mp, result.cch +} + +// IsAnthropicCacheTTL1hInjectionEnabled 检查是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl。 +func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool { + return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h } // IsEmailVerifyEnabled 检查是否开启邮件验证 @@ -1512,6 +1558,54 @@ func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) floa return clampAffiliateRebateRate(rate) } +// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。 +// 返回 0 表示不冻结(向后兼容)。 +func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours) + if err != nil { + return AffiliateRebateFreezeHoursDefault + } + hours, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || hours < 0 { + return AffiliateRebateFreezeHoursDefault + } + if hours > AffiliateRebateFreezeHoursMax { + return AffiliateRebateFreezeHoursMax + } + return hours +} + +// GetAffiliateRebateDurationDays 返回返利有效期(天)。 +// 返回 0 表示永久有效。 +func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays) + if err != nil { + return AffiliateRebateDurationDaysDefault + } + days, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil || days < 0 { + return AffiliateRebateDurationDaysDefault + } + if days > AffiliateRebateDurationDaysMax { + return AffiliateRebateDurationDaysMax + } + return days +} + +// GetAffiliateRebatePerInviteeCap 返回单人返利上限。 +// 返回 0 表示无上限。 +func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap) + if err != nil { 
+ return AffiliateRebatePerInviteeCapDefault + } + cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64) + if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) { + return AffiliateRebatePerInviteeCapDefault + } + return cap +} + // IsPasswordResetEnabled 检查是否启用密码重置功能 // 要求:必须同时开启邮件验证 func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { @@ -1755,6 +1849,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), + SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault), + SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault), + SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64), SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", @@ -1811,12 +1908,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", - SettingPaymentVisibleMethodAlipaySource: "", - SettingPaymentVisibleMethodWxpaySource: "", - SettingPaymentVisibleMethodAlipayEnabled: "false", - SettingPaymentVisibleMethodWxpayEnabled: "false", - openAIAdvancedSchedulerSettingKey: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyEnableAnthropicCacheTTL1hInjection: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, 
defaults) @@ -1890,6 +1988,21 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.AffiliateRebateRate = AffiliateRebateRateDefault } + if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 { + if freezeHours > AffiliateRebateFreezeHoursMax { + freezeHours = AffiliateRebateFreezeHoursMax + } + result.AffiliateRebateFreezeHours = freezeHours + } + if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 { + if durationDays > AffiliateRebateDurationDaysMax { + durationDays = AffiliateRebateDurationDaysMax + } + result.AffiliateRebateDurationDays = durationDays + } + if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 { + result.AffiliateRebatePerInviteeCap = perInviteeCap + } result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 @@ -2144,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" + result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" // Web search emulation: quick enabled check from the JSON config if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" { @@ -3175,6 +3289,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data)) } +// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置 +func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings) + if 
err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultOpenAIFastPolicySettings(), nil + } + return nil, fmt.Errorf("get openai fast policy settings: %w", err) + } + if value == "" { + return DefaultOpenAIFastPolicySettings(), nil + } + + var settings OpenAIFastPolicySettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + // JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配 + // 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常 + // 行为时定位到 settings 表里的脏数据。 + slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults", + "error", err, + "key", SettingKeyOpenAIFastPolicySettings) + return DefaultOpenAIFastPolicySettings(), nil + } + + return &settings, nil +} + +// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置 +func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + validActions := map[string]bool{ + BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, + } + validScopes := map[string]bool{ + BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true, + } + validTiers := map[string]bool{ + OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true, + } + + for i, rule := range settings.Rules { + tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier)) + if tier == "" { + tier = OpenAIFastTierAny + } + if !validTiers[tier] { + return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier) + } + settings.Rules[i].ServiceTier = tier + if !validActions[rule.Action] { + return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action) + } + if !validScopes[rule.Scope] { + return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope) + } + for j, pattern := range rule.ModelWhitelist { + trimmed := strings.TrimSpace(pattern) + if trimmed == "" { + return 
fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j) + } + settings.Rules[i].ModelWhitelist[j] = trimmed + } + if rule.FallbackAction != "" && !validActions[rule.FallbackAction] { + return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction) + } + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal openai fast policy settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data)) +} + // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 70d8efc3..41c01cca 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -104,12 +104,15 @@ type SystemSettings struct { CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints - DefaultConcurrency int - DefaultBalance float64 - AffiliateEnabled bool - AffiliateRebateRate float64 - DefaultUserRPMLimit int - DefaultSubscriptions []DefaultSubscriptionSetting + DefaultConcurrency int + DefaultBalance float64 + AffiliateEnabled bool + AffiliateRebateRate float64 + AffiliateRebateFreezeHours int + AffiliateRebateDurationDays int + AffiliateRebatePerInviteeCap float64 + DefaultUserRPMLimit int + DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -146,9 +149,10 @@ type SystemSettings struct { BackendModeEnabled bool // Gateway forwarding behavior - EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) - EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) - EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) + 
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) + EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false) // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 @@ -402,3 +406,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings { }, } } + +// OpenAI Fast Policy 策略常量 +// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别: +// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式 +// - "flex":低优先级模式 +// - 省略:normal 默认 +// +// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从 +// anthropic-beta header 换成 body 的 service_tier 字段。 +const ( + OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier + OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority) + OpenAIFastTierFlex = "flex" // 仅匹配 flex +) + +// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则 +type OpenAIFastPolicyRule struct { + ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all" + Action string `json:"action"` // "pass" | "filter" | "block" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" + ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) + ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效) + FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式 + FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效) +} + +// OpenAIFastPolicySettings OpenAI fast 策略配置 +type OpenAIFastPolicySettings struct { + Rules []OpenAIFastPolicyRule `json:"rules"` +} + +// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。 +// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段, +// 让上游按 normal 优先级处理。 +// +// 为什么 ModelWhitelist 为空(=对所有模型生效): +// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使 +// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁 +// gpt-5.5*,"用 
gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。 +// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定 +// 模型,可在 admin UI 中显式配置 model_whitelist。 +func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings { + return &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{ + { + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{}, + FallbackAction: BetaPolicyActionPass, + }, + }, + } +} diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go new file mode 100644 index 00000000..4430cf81 --- /dev/null +++ b/backend/internal/service/vertex_service_account.go @@ -0,0 +1,345 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + vertexDefaultLocation = "us-central1" + vertexDefaultTokenURL = "https://oauth2.googleapis.com/token" + vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" + vertexServiceAccountCacheSkew = 5 * time.Minute + vertexLockWaitTime = 200 * time.Millisecond + vertexAnthropicVersion = "vertex-2023-10-16" +) + +var ( + vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`) + vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`) + vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`) +) + +type vertexServiceAccountKey struct { + Type string `json:"type"` + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + TokenURI string `json:"token_uri"` +} + +type vertexTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + Error string `json:"error"` + ErrorDesc string 
`json:"error_description"` +} + +func (a *Account) IsVertexServiceAccount() bool { + return a != nil && a.Type == AccountTypeServiceAccount +} + +func (a *Account) VertexProjectID() string { + if a == nil { + return "" + } + if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" { + return v + } + key, err := parseVertexServiceAccountKey(a) + if err == nil { + return strings.TrimSpace(key.ProjectID) + } + return "" +} + +func (a *Account) VertexLocation(model string) string { + if a == nil { + return vertexDefaultLocation + } + if model != "" && a.Credentials != nil { + if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok { + if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" { + return strings.TrimSpace(loc) + } + } + } + if v := strings.TrimSpace(a.GetCredential("location")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" { + return v + } + return vertexDefaultLocation +} + +func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) { + if account == nil || account.Credentials == nil { + return nil, errors.New("service account credentials not configured") + } + + if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" { + return parseVertexServiceAccountJSON([]byte(raw)) + } + if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" { + return parseVertexServiceAccountJSON([]byte(raw)) + } + if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok { + b, _ := json.Marshal(nested) + return parseVertexServiceAccountJSON(b) + } + if nested, ok := account.Credentials["service_account"].(map[string]any); ok { + b, _ := json.Marshal(nested) + return parseVertexServiceAccountJSON(b) + } + return nil, errors.New("service_account_json not found in credentials") +} + +func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) { + var key 
vertexServiceAccountKey + if err := json.Unmarshal(raw, &key); err != nil { + return nil, fmt.Errorf("invalid service account json: %w", err) + } + if strings.TrimSpace(key.ClientEmail) == "" { + return nil, errors.New("service account json missing client_email") + } + if strings.TrimSpace(key.PrivateKey) == "" { + return nil, errors.New("service account json missing private_key") + } + if strings.TrimSpace(key.ProjectID) == "" { + return nil, errors.New("service account json missing project_id") + } + // Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri. + key.TokenURI = vertexDefaultTokenURL + return &key, nil +} + +func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string { + fingerprint := "" + if key != nil { + sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID)) + fingerprint = hex.EncodeToString(sum[:8]) + } + if fingerprint == "" && account != nil { + fingerprint = fmt.Sprintf("account:%d", account.ID) + } + return "vertex:service_account:" + fingerprint +} + +// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account, +// using the shared cache and distributed lock to avoid redundant exchanges. 
+func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) { + key, err := parseVertexServiceAccountKey(account) + if err != nil { + return "", err + } + cacheKey := vertexServiceAccountCacheKey(account, key) + + if cache != nil { + if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + locked := false + if cache != nil { + var lockErr error + locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if lockErr == nil && locked { + defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }() + } else if lockErr != nil { + slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr) + } else { + time.Sleep(vertexLockWaitTime) + if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + } + + accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key) + if err != nil { + return "", err + } + if cache != nil { + _ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + return accessToken, nil +} + +func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) { + now := time.Now() + claims := jwt.MapClaims{ + "iss": key.ClientEmail, + "scope": vertexCloudPlatformScope, + "aud": key.TokenURI, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + if strings.TrimSpace(key.PrivateKeyID) != "" { + token.Header["kid"] = key.PrivateKeyID + } + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey)) + if err != nil { + return "", 0, fmt.Errorf("parse service account private key: %w", err) + } + assertion, err := token.SignedString(privateKey) + if err != nil { + return "", 0, fmt.Errorf("sign service account assertion: %w", err) + } + + values := 
url.Values{} + values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + values.Set("assertion", assertion) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode())) + if err != nil { + return "", 0, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", 0, fmt.Errorf("service account token request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + var parsed vertexTokenResponse + _ = json.Unmarshal(body, &parsed) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg := strings.TrimSpace(parsed.ErrorDesc) + if msg == "" { + msg = strings.TrimSpace(parsed.Error) + } + if msg == "" { + msg = string(bytes.TrimSpace(body)) + } + return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg) + } + if strings.TrimSpace(parsed.AccessToken) == "" { + return "", 0, errors.New("service account token response missing access_token") + } + ttl := time.Duration(parsed.ExpiresIn) * time.Second + if ttl <= 0 { + ttl = time.Hour + } + if ttl > vertexServiceAccountCacheSkew { + ttl -= vertexServiceAccountCacheSkew + } + return parsed.AccessToken, ttl, nil +} + +func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) { + projectID = strings.TrimSpace(projectID) + location = strings.TrimSpace(location) + model = strings.TrimSpace(model) + action = strings.TrimSpace(action) + if projectID == "" { + return "", errors.New("vertex project_id is required") + } + if location == "" { + location = vertexDefaultLocation + } + if !vertexLocationPattern.MatchString(location) { + return "", fmt.Errorf("invalid vertex location: %s", location) + } + if model == "" { + return "", errors.New("vertex model is required") + } + switch action 
{ + case "generateContent", "streamGenerateContent", "countTokens": + default: + return "", fmt.Errorf("unsupported vertex gemini action: %s", action) + } + host := fmt.Sprintf("%s-aiplatform.googleapis.com", location) + if location == "global" { + host = "aiplatform.googleapis.com" + } + u := fmt.Sprintf( + "https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + host, + url.PathEscape(projectID), + url.PathEscape(location), + url.PathEscape(model), + action, + ) + if stream { + u += "?alt=sse" + } + return u, nil +} + +func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) { + projectID = strings.TrimSpace(projectID) + location = strings.TrimSpace(location) + model = strings.TrimSpace(model) + if projectID == "" { + return "", errors.New("vertex project_id is required") + } + if location == "" { + location = vertexDefaultLocation + } + if !vertexLocationPattern.MatchString(location) { + return "", fmt.Errorf("invalid vertex location: %s", location) + } + if model == "" { + return "", errors.New("vertex model is required") + } + action := "rawPredict" + if stream { + action = "streamRawPredict" + } + host := fmt.Sprintf("%s-aiplatform.googleapis.com", location) + if location == "global" { + host = "aiplatform.googleapis.com" + } + escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@") + return fmt.Sprintf( + "https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", + host, + url.PathEscape(projectID), + url.PathEscape(location), + escapedModel, + action, + ), nil +} + +func normalizeVertexAnthropicModelID(model string) string { + model = strings.TrimSpace(model) + if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) { + return model + } + if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 { + return m[1] + "@" + m[2] + } + return model +} + +func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) { + var payload 
map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("parse anthropic vertex request body: %w", err) + } + delete(payload, "model") + payload["anthropic_version"] = vertexAnthropicVersion + return json.Marshal(payload) +} diff --git a/backend/internal/service/vertex_service_account_test.go b/backend/internal/service/vertex_service_account_test.go new file mode 100644 index 00000000..519f5b2f --- /dev/null +++ b/backend/internal/service/vertex_service_account_test.go @@ -0,0 +1,77 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildVertexGeminiURL(t *testing.T) { + got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true) + require.NoError(t, err) + require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got) +} + +func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) { + got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true) + require.NoError(t, err) + require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got) +} + +func TestBuildVertexAnthropicURL(t *testing.T) { + got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false) + require.NoError(t, err) + require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got) +} + +func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) { + got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true) + require.NoError(t, err) + 
require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got) +} + +func TestNormalizeVertexAnthropicModelID(t *testing.T) { + require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929")) + require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929")) + require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6")) +} + +func TestBuildVertexAnthropicRequestBody(t *testing.T) { + got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)) + require.NoError(t, err) + require.Equal(t, "", gjson.GetBytes(got, "model").String()) + require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String()) + require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int()) + require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String()) +} + +func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) { + _, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vertex location") +} + +func TestParseVertexServiceAccountKey(t *testing.T) { + raw := `{ + "type": "service_account", + "project_id": "vertex-proj", + "private_key_id": "kid", + "private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n", + "client_email": "svc@vertex-proj.iam.gserviceaccount.com" + }` + account := &Account{ + Type: AccountTypeServiceAccount, + Platform: PlatformGemini, + Credentials: map[string]any{ + "service_account_json": raw, + }, + } + key, err := parseVertexServiceAccountKey(account) + require.NoError(t, err) + require.Equal(t, "vertex-proj", key.ProjectID) + 
require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail) + require.Equal(t, vertexDefaultTokenURL, key.TokenURI) + require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY")) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d1a50913..1b7cb7ac 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -407,12 +407,28 @@ func ProvideBillingCacheService( return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) } +// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation. +func ProvideAPIKeyService( + apiKeyRepo APIKeyRepository, + userRepo UserRepository, + groupRepo GroupRepository, + userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, + cache APIKeyCache, + cfg *config.Config, + billingCacheService *BillingCacheService, +) *APIKeyService { + svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg) + svc.SetRateLimitCacheInvalidator(billingCacheService) + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, - NewAPIKeyService, + ProvideAPIKeyService, ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService, diff --git a/backend/migrations/133_affiliate_rebate_freeze.sql b/backend/migrations/133_affiliate_rebate_freeze.sql new file mode 100644 index 00000000..b87d59b7 --- /dev/null +++ b/backend/migrations/133_affiliate_rebate_freeze.sql @@ -0,0 +1,17 @@ +-- 1) Add frozen quota column to user_affiliates for rebate freeze period. 
+ALTER TABLE user_affiliates + ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0; + +COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)'; + +-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking. +-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp. +ALTER TABLE user_affiliate_ledger + ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL; + +COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen'; + +-- 3) Partial index for efficient thaw queries (only rows still frozen). +CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw + ON user_affiliate_ledger (user_id, frozen_until) + WHERE frozen_until IS NOT NULL; diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index a484d7ed..07a68c03 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -74,6 +74,26 @@ describe('oauth adoption auth api', () => { }) }) + it('posts affiliate code when completing linuxdo oauth registration', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration( + 'invite-code', + { + adoptDisplayName: true, + adoptAvatar: false + }, + ' AFF123 ' + ) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + aff_code: 'AFF123', + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts oidc invitation completion with adoption decisions', async () => { const { completeOIDCOAuthRegistration } = await import('@/api/auth') @@ -134,6 +154,26 @@ describe('oauth adoption auth api', () => { }) }) + it('posts affiliate code when creating pending wechat oauth account', async () => { + const { 
createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount( + 'invite-code', + { + adoptDisplayName: false, + adoptAvatar: true + }, + 'WXAFF' + ) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + aff_code: 'WXAFF', + adopt_display_name: false, + adopt_avatar: true + }) + }) + it('classifies oauth completion results as login or bind', async () => { const { getOAuthCompletionKind } = await import('@/api/auth') diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index a146f1f7..8a127793 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: { * @returns Success confirmation */ export async function bulkUpdate( - accountIds: number[], - updates: Record + accountIdsOrPayload: number[] | Record, + updates?: Record ): Promise<{ success: number failed: number @@ -379,16 +379,19 @@ export async function bulkUpdate( failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> }> { + const payload = Array.isArray(accountIdsOrPayload) + ? { + account_ids: accountIdsOrPayload, + ...(updates ?? 
{}) + } + : accountIdsOrPayload const { data } = await apiClient.post<{ success: number failed: number success_ids?: number[] failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> - }>('/admin/accounts/bulk-update', { - account_ids: accountIds, - ...updates - }) + }>('/admin/accounts/bulk-update', payload) return data } diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index cf8626fc..b887355a 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -309,6 +309,9 @@ export interface SystemSettings { // Default settings default_balance: number; affiliate_rebate_rate: number; + affiliate_rebate_freeze_hours: number; + affiliate_rebate_duration_days: number; + affiliate_rebate_per_invitee_cap: number; default_concurrency: number; default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; @@ -437,6 +440,7 @@ export interface SystemSettings { enable_fingerprint_unification: boolean; enable_metadata_passthrough: boolean; enable_cch_signing: boolean; + enable_anthropic_cache_ttl_1h_injection: boolean; web_search_emulation_enabled?: boolean; // Payment configuration @@ -482,6 +486,9 @@ export interface SystemSettings { // Affiliate (邀请返利) feature switch affiliate_enabled: boolean; + + // OpenAI fast/flex policy + openai_fast_policy_settings?: OpenAIFastPolicySettings; } export interface UpdateSettingsRequest { @@ -495,6 +502,9 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; affiliate_rebate_rate?: number; + affiliate_rebate_freeze_hours?: number; + affiliate_rebate_duration_days?: number; + affiliate_rebate_per_invitee_cap?: number; default_concurrency?: number; default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; @@ -602,6 +612,7 @@ export interface UpdateSettingsRequest { enable_fingerprint_unification?: boolean; enable_metadata_passthrough?: 
boolean; enable_cch_signing?: boolean; + enable_anthropic_cache_ttl_1h_injection?: boolean; // Payment configuration payment_enabled?: boolean; payment_min_amount?: number; @@ -644,6 +655,9 @@ export interface UpdateSettingsRequest { // Affiliate (邀请返利) feature switch affiliate_enabled?: boolean; + + // OpenAI fast/flex policy + openai_fast_policy_settings?: OpenAIFastPolicySettings; } /** @@ -871,6 +885,29 @@ export async function updateRectifierSettings( return data; } +// ==================== OpenAI Fast Policy Settings ==================== + +/** + * OpenAI fast/flex policy rule interface. + * Matches backend dto.OpenAIFastPolicyRule. + */ +export interface OpenAIFastPolicyRule { + service_tier: "all" | "priority" | "flex"; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; +} + +/** + * OpenAI fast/flex policy settings interface. 
+ */ +export interface OpenAIFastPolicySettings { + rules: OpenAIFastPolicyRule[]; +} + // ==================== Beta Policy Settings ==================== /** diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index f49f3a1f..bb990fc4 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -564,9 +564,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { - return createPendingLinuxDoOAuthAccount(invitationCode, decision) + return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode) } /** @@ -576,27 +577,32 @@ export async function completeLinuxDoOAuthRegistration( */ export async function completeOIDCOAuthRegistration( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOIDCOAuthAccount(invitationCode, decision) + return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode) } export async function completeWeChatOAuthRegistration( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingWeChatOAuthAccount(invitationCode, decision) + return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode) } async function createPendingOAuthAccount( provider: 'linuxdo' | 'oidc' | 'wechat', invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { + const normalizedAffiliateCode = affiliateCode?.trim() const { data } = await apiClient.post( `/auth/oauth/${provider}/complete-registration`, { invitation_code: invitationCode, + ...(normalizedAffiliateCode ? 
{ aff_code: normalizedAffiliateCode } : {}), ...serializeOAuthAdoptionDecision(decision) } ) @@ -605,23 +611,26 @@ async function createPendingOAuthAccount( export async function createPendingLinuxDoOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('linuxdo', invitationCode, decision) + return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode) } export async function createPendingOIDCOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('oidc', invitationCode, decision) + return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode) } export async function createPendingWeChatOAuthAccount( invitationCode: string, - decision?: OAuthAdoptionDecision + decision?: OAuthAdoptionDecision, + affiliateCode?: string ): Promise { - return createPendingOAuthAccount('wechat', invitationCode, decision) + return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode) } export async function completePendingOAuthBindLogin( diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 41dd1505..90a67922 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -332,6 +332,37 @@
+
+
+ + {{ formatKeyRequests }} req + + + {{ formatKeyTokens }} + + + A ${{ formatKeyCost }} + + + U ${{ formatKeyUserCost }} + +
+
+
+
+
+
+
@@ -545,6 +576,10 @@ const shouldFetchUsage = computed(() => { return false }) +const showGeminiTodayStats = computed(() => { + return props.account.platform === 'gemini' && props.account.type === 'service_account' +}) + const geminiUsageAvailable = computed(() => { return ( !!usageInfo.value?.gemini_shared_daily || diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 13c30cf9..05016a6d 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -17,7 +17,7 @@ d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /> - {{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }} + {{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}

@@ -27,7 +27,7 @@ - {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }} + {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}

@@ -227,7 +227,7 @@

@@ -698,6 +698,87 @@

+ +
+
+ + +
+
+

+ {{ t('admin.accounts.openai.codexCLIOnlyDesc') }} +

+ +
+
+ + +
+
+ + +
+
+

+ {{ t('admin.accounts.openai.wsModeDesc') }} +

+

+ {{ t(openAIAPIKeyWSModeConcurrencyHintKey) }} +

+
+ +
+
+ + +
+
+
+
+ + {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }} +
+

+ {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }} +

+
+ +
+
+
Project ID: {{ vertexProjectId }}
+
Client Email: {{ vertexClientEmail }}
+
+
+

{{ t('admin.accounts.vertexSaJsonUploadHint') }}

+
+ +
+
+ + +
+
+ + +

{{ t('admin.accounts.vertexLocationHint') }}

+
+
+
+
@@ -3119,6 +3280,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' +import { VERTEX_LOCATION_OPTIONS } from '@/constants/account' import { OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, @@ -3233,7 +3395,7 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') @@ -3306,6 +3468,12 @@ const bedrockSessionToken = ref('') const bedrockRegion = ref('us-east-1') const bedrockForceGlobal = ref(false) const bedrockApiKeyValue = ref('') +const vertexServiceAccountFileInput = ref(null) +const vertexServiceAccountJson = ref('') +const vertexProjectId = ref('') +const vertexClientEmail = ref('') +const vertexLocation = ref('global') +const vertexServiceAccountDragActive = ref(false) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -3556,7 +3724,7 @@ watch( // Sync form.type based on accountCategory, addMethod, and platform-specific type watch( - [accountCategory, addMethod, antigravityAccountType], + [accountCategory, addMethod, antigravityAccountType, () => form.platform], ([category, method, agType]) => { // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { @@ -3568,7 +3736,9 @@ watch( form.type = 'bedrock' as AccountType 
return } - if (category === 'oauth-based') { + if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') { + form.type = 'service_account' as AccountType + } else if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' } else { form.type = 'apikey' @@ -3606,6 +3776,12 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') { + accountCategory.value = 'oauth-based' + } + if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') { + accountCategory.value = 'oauth-based' + } // Reset Bedrock fields when switching platforms bedrockAccessKeyId.value = '' bedrockSecretAccessKey.value = '' @@ -3614,6 +3790,10 @@ watch( bedrockForceGlobal.value = false bedrockAuthMode.value = 'sigv4' bedrockApiKeyValue.value = '' + vertexServiceAccountJson.value = '' + vertexProjectId.value = '' + vertexClientEmail.value = '' + vertexLocation.value = 'global' // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -4068,6 +4248,10 @@ const resetForm = () => { upstreamType.value = 'sub2api' upstreamBaseUrl.value = '' upstreamApiKey.value = '' + vertexServiceAccountJson.value = '' + vertexProjectId.value = '' + vertexClientEmail.value = '' + vertexLocation.value = 'global' tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -4195,6 +4379,52 @@ const normalizePoolModeRetryCount = (value: number) => { return normalized } +const applyVertexServiceAccountJson = (value: string) => { + const raw = value.trim() + if (!raw) { + vertexProjectId.value = '' + vertexClientEmail.value = '' + return false + } + try { + const parsed = JSON.parse(raw) as Record + const projectId = 
typeof parsed.project_id === 'string' ? parsed.project_id.trim() : '' + const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : '' + const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : '' + if (!projectId || !clientEmail || !privateKey) { + appStore.showError(t('admin.accounts.vertexSaJsonMissingFields')) + return false + } + vertexProjectId.value = projectId + vertexClientEmail.value = clientEmail + vertexServiceAccountJson.value = JSON.stringify(parsed) + return true + } catch { + appStore.showError(t('admin.accounts.vertexSaJsonInvalid')) + return false + } +} + +const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value) + +const handleVertexServiceAccountFile = async (event: Event) => { + const input = event.target as HTMLInputElement + const file = input.files?.[0] + if (!file) return + try { + applyVertexServiceAccountJson(await file.text()) + } finally { + input.value = '' + } +} + +const handleVertexServiceAccountDrop = async (event: DragEvent) => { + vertexServiceAccountDragActive.value = false + const file = event.dataTransfer?.files?.[0] + if (!file) return + applyVertexServiceAccountJson(await file.text()) +} + const handleSubmit = async () => { // For OAuth-based type, handle OAuth flow (goes to step 2) if (isOAuthFlow.value) { @@ -4348,6 +4578,29 @@ const handleSubmit = async () => { return } + if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!parseVertexServiceAccountJson()) { + return + } + if (!vertexLocation.value.trim()) { + appStore.showError(t('admin.accounts.vertexLocationRequired')) + return + } + const credentials: Record = { + service_account_json: vertexServiceAccountJson.value.trim(), + project_id: vertexProjectId.value.trim(), + client_email: 
vertexClientEmail.value.trim(), + location: vertexLocation.value.trim(), + tier_id: 'vertex' + } + await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials) + return + } + // For apikey type, create directly if (!apiKeyValue.value.trim()) { appStore.showError(t('admin.accounts.pleaseEnterApiKey')) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 85f05395..b117bff3 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -599,6 +599,221 @@
+ +
+
+
+ + +

{{ t('admin.accounts.vertexSaJsonEditHint') }}

+
+
+ + +

{{ t('admin.accounts.vertexLocationHint') }}

+
+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ + t('admin.accounts.supportsAllModels') + }} +

+
+ + +
+
+

+ + + + {{ t('admin.accounts.mapRequestModels') }} +

+
+ + +
+
+ + + + + + +
+
+ + + + +
+ +
+
+
+
+
@@ -1951,6 +2166,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' +import { VERTEX_LOCATION_OPTIONS } from '@/constants/account' import { OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, @@ -2020,6 +2236,9 @@ const editBedrockSessionToken = ref('') const editBedrockRegion = ref('') const editBedrockForceGlobal = ref(false) const editBedrockApiKeyValue = ref('') +const editVertexProjectId = ref('') +const editVertexClientEmail = ref('') +const editVertexLocation = ref('us-central1') const isBedrockAPIKeyMode = computed(() => props.account?.type === 'bedrock' && (props.account?.credentials as Record)?.auth_mode === 'apikey' @@ -2279,6 +2498,9 @@ const syncFormFromAccount = (newAccount: Account | null) => { const credentials = newAccount.credentials as Record | undefined interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true + editVertexProjectId.value = '' + editVertexClientEmail.value = '' + editVertexLocation.value = 'us-central1' // Load mixed scheduling setting (only for antigravity accounts) mixedScheduling.value = false @@ -2504,6 +2726,31 @@ const syncFormFromAccount = (newAccount: Account | null) => { } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' + } else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editVertexProjectId.value = (credentials.project_id as string) || '' + editVertexClientEmail.value = 
(credentials.client_email as string) || '' + editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1' + + // Load model mappings for service_account + const existingMappings = credentials.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -3099,6 +3346,46 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + if (!editVertexProjectId.value.trim()) { + appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId')) + return + } + if (!editVertexClientEmail.value.trim()) { + appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail')) + return + } + if (!editVertexLocation.value.trim()) { + appStore.showError(t('admin.accounts.vertexLocationRequired')) + return + } + + if (!currentCredentials.service_account_json && !currentCredentials.service_account) { + appStore.showError(t('admin.accounts.vertexSaJsonRequired')) + return + } + newCredentials.project_id = editVertexProjectId.value.trim() + newCredentials.client_email 
= editVertexClientEmail.value.trim() + newCredentials.location = editVertexLocation.value.trim() + newCredentials.tier_id = 'vertex' + + // Add model mapping if configured + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else if (props.account.type === 'bedrock') { const currentCredentials = (props.account.credentials as Record) || {} diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts index 9158da64..fa4104f6 100644 --- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts +++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts @@ -57,6 +57,19 @@ function makeAccount(overrides: Partial): Account { describe('AccountUsageCell', () => { beforeEach(() => { getUsage.mockReset() + Object.defineProperty(window, 'matchMedia', { + writable: true, + value: vi.fn().mockImplementation(() => ({ + matches: true, + media: '(min-width: 768px)', + onchange: null, + addListener: vi.fn(), + removeListener: vi.fn(), + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + dispatchEvent: vi.fn(), + })) + }) }) it('Antigravity 图片用量会聚合新旧 image 模型', async () => { @@ -603,4 +616,43 @@ describe('AccountUsageCell', () => { expect(wrapper.text().trim()).toBe('-') }) + + it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => { + const wrapper = mount(AccountUsageCell, { + props: { + account: makeAccount({ + id: 4001, + platform: 'gemini', + type: 'service_account', + credentials: { + tier_id: 'vertex', + project_id: 'vertex-proj', + client_email: 
'svc@vertex-proj.iam.gserviceaccount.com', + location: 'global' + }, + extra: {} + }), + todayStats: { + requests: 0, + tokens: 0, + cost: 0, + standard_cost: 0, + user_cost: 0 + } + }, + global: { + stubs: { + UsageProgressBar: true, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('0 req') + expect(wrapper.text()).toContain('0') + expect(wrapper.text()).toContain('A $0.00') + expect(wrapper.text()).toContain('U $0.00') + }) }) diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts index 7390e723..50d170da 100644 --- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts @@ -178,6 +178,45 @@ describe('BulkEditAccountModal', () => { expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false) }) + it('OpenAI OAuth 批量编辑应提交 codex_cli_only 字段', async () => { + const wrapper = mountModal({ + selectedPlatforms: ['openai'], + selectedTypes: ['oauth'] + }) + + await wrapper.get('#bulk-edit-openai-codex-cli-only-enabled').setValue(true) + await wrapper.get('#bulk-edit-openai-codex-cli-only-toggle').trigger('click') + await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent') + await flushPromises() + + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1) + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], { + extra: { + codex_cli_only: true + } + }) + }) + + it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => { + const wrapper = mountModal({ + selectedPlatforms: ['openai'], + selectedTypes: ['apikey'] + }) + + await wrapper.get('#bulk-edit-openai-apikey-ws-mode-enabled').setValue(true) + await wrapper.get('[data-testid="bulk-edit-openai-apikey-ws-mode-select"]').setValue('ctx_pool') + await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent') + await 
flushPromises() + + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1) + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], { + extra: { + openai_apikey_responses_websockets_v2_mode: 'ctx_pool', + openai_apikey_responses_websockets_v2_enabled: true + } + }) + }) + it('OpenAI 账号批量编辑可关闭自动透传', async () => { const wrapper = mountModal({ selectedPlatforms: ['openai'], @@ -217,4 +256,41 @@ describe('BulkEditAccountModal', () => { }) expect(wrapper.text()).toContain('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }) + + it('filtered-results 模式下应提交 filters 而不是 account_ids', async () => { + const wrapper = mountModal({ + accountIds: [], + target: { + mode: 'filtered', + filters: { + platform: 'openai', + type: 'oauth', + status: 'active', + group: '12', + search: 'bulk-target', + privacy_mode: 'training_set_cf_blocked' + }, + previewCount: 5, + selectedPlatforms: ['openai'], + selectedTypes: ['oauth'] + } + }) + + await wrapper.get('#bulk-edit-status-enabled').setValue(true) + await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent') + await flushPromises() + + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1) + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({ + filters: { + platform: 'openai', + type: 'oauth', + status: 'active', + group: '12', + search: 'bulk-target', + privacy_mode: 'training_set_cf_blocked' + }, + status: 'active' + }) + }) }) diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue index 3b987bd0..a632bdd4 100644 --- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue +++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue @@ -1,9 +1,13 @@ diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue index c740d06f..6b245123 100644 --- a/frontend/src/components/auth/LinuxDoOAuthSection.vue +++ 
b/frontend/src/components/auth/LinuxDoOAuthSection.vue @@ -42,9 +42,11 @@