diff --git a/.github/audit-exceptions.yml b/.github/audit-exceptions.yml index a1d8411c..b71422a7 100644 --- a/.github/audit-exceptions.yml +++ b/.github/audit-exceptions.yml @@ -5,12 +5,26 @@ exceptions: severity: high reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2023-30533)" mitigation: "Load only on export; restrict export permissions and data scope" - expires_on: "2026-04-05" + expires_on: "2026-07-06" owner: "security@your-domain" - package: xlsx advisory: "GHSA-5pgg-2g8v-p4x9" severity: high reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2024-22363)" mitigation: "Load only on export; restrict export permissions and data scope" - expires_on: "2026-04-05" + expires_on: "2026-07-06" + owner: "security@your-domain" + - package: lodash + advisory: "GHSA-r5fr-rjxr-66jc" + severity: high + reason: "lodash _.template not used with untrusted input; only internal admin UI templates" + mitigation: "No user-controlled template strings; plan to migrate to lodash-es tree-shaken imports" + expires_on: "2026-07-02" + owner: "security@your-domain" + - package: lodash-es + advisory: "GHSA-r5fr-rjxr-66jc" + severity: high + reason: "lodash-es _.template not used with untrusted input; only internal admin UI templates" + mitigation: "No user-controlled template strings; plan to migrate to native JS alternatives" + expires_on: "2026-07-02" owner: "security@your-domain" diff --git a/README.md b/README.md index 99753e45..50611a6d 100644 --- a/README.md +++ b/README.md @@ -45,17 +45,30 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - **Admin Dashboard** - Web interface for monitoring and management - **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard -## Don't Want to Self-Host? +## ❤️ Sponsors + +> [Want to appear here?](mailto:support@pincc.ai) + + + + + + + + + + +
pincc PinCC is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.
PackyCode Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "sub2api" promo code during first recharge to get 10% off.
PoixeAiThanks to Poixe Ai for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive sub2api referral link and receive a bonus of $5 USD on your first top-up.
CTokThanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click here to register!
## Ecosystem diff --git a/README_CN.md b/README_CN.md index 8b6feaba..797f106b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -44,17 +44,31 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 - **管理后台** - Web 界面进行监控和管理 - **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能 -## 不想自建?试试官方中转 +## ❤️ 赞助商 + +> [想出现在这里?](mailto:support@pincc.ai) + + + + + + + + + + + +
pincc PinCC 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。
PackyCode 感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用此链接注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!
PoixeAI感谢 Poixe AI 赞助了本项目!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 此链接 专属链接注册,充值额外赠送 $5 美金
CTok感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击这里注册!
## 生态项目 diff --git a/README_JA.md b/README_JA.md index 1266bd84..b7820554 100644 --- a/README_JA.md +++ b/README_JA.md @@ -45,7 +45,9 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを - **管理ダッシュボード** - 監視・管理のための Web インターフェース - **外部システム連携** - 外部システム(決済、チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能 -## セルフホストが不要な方へ +## ❤️ スポンサー + +> [こちらに掲載しませんか?](mailto:support@pincc.ai) @@ -56,6 +58,16 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを + + + + + + + + + +
PackyCode PackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:このリンクで登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。
PoixeAiPoixe AI のご支援に感謝します!Poixe AI は信頼性の高い LLM API サービスを提供しています。プラットフォームの API エンドポイントを活用して、AI 搭載プロダクトをシームレスに構築できます。また、ベンダーとして AI API リソースをプラットフォームに提供し、収益を得ることも可能です。専用の sub2api 紹介リンクから登録すると、初回チャージ時に $5 USD のボーナスがもらえます。
CTokCTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。こちらから登録!
## エコシステム diff --git a/assets/partners/logos/ctok.png b/assets/partners/logos/ctok.png new file mode 100644 index 00000000..cf6fcf17 Binary files /dev/null and b/assets/partners/logos/ctok.png differ diff --git a/assets/partners/logos/poixe.png b/assets/partners/logos/poixe.png new file mode 100644 index 00000000..aa89cb06 Binary files /dev/null and b/assets/partners/logos/poixe.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 9e3db2aa..94e74b8e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.106 +0.1.109 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index d1070b88..87c5355e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -54,6 +54,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) + channelRepository := repository.NewChannelRepository(db) settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) @@ -106,12 +107,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) - soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -143,11 +143,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) @@ -180,18 +180,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) + modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - soraS3Storage := service.NewSoraS3Storage(settingService) - settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) - soraGenerationRepository := repository.NewSoraGenerationRepository(db) - soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) - soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -218,22 +215,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) + channelHandler := admin.NewChannelHandler(channelService, billingService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) - soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) - soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) - soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) - soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -244,13 +237,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) lsPoolBootstrapService := service.ProvideLSPoolBootstrapService(accountRepository, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, lsPoolBootstrapService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, lsPoolBootstrapService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) application := &Application{ Server: httpServer, Cleanup: v, @@ -285,7 +277,6 @@ func provideCleanup( opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, opsSystemLogSink *service.OpsSystemLogSink, - soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, lsPoolBootstrap *service.LSPoolBootstrapService, @@ -334,12 +325,6 @@ func provideCleanup( } return nil }}, - {"SoraMediaCleanupService", func() error { - if soraMediaCleanup != nil { - soraMediaCleanup.Stop() - } - return nil - }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index eb7a0877..52f4aa3c 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -58,7 +58,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { &service.OpsCleanupService{}, &service.OpsScheduledReportService{}, opsSystemLogSinkSvc, - &service.SoraMediaCleanupService{}, schedulerSnapshotSvc, tokenRefreshSvc, lsPoolBootstrapSvc, diff --git a/backend/ent/group.go b/backend/ent/group.go index 3db54a64..b15ac15d 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,16 +52,6 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` - // SoraImagePrice360 holds the value of the "sora_image_price_360" field. - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - // SoraImagePrice540 holds the value of the "sora_image_price_540" field. - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` - // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. - SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` - // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -80,6 +70,10 @@ type Group struct { SortOrder int `json:"sort_order,omitempty"` // 是否允许 /v1/messages 调度到此 OpenAI 分组 AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` + // 仅允许非 apikey 类型账号关联到此分组 + RequireOauthOnly bool `json:"require_oauth_only,omitempty"` + // 调度时仅允许 privacy 已成功设置的账号 + RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -190,11 +184,11 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) @@ -331,40 +325,6 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } - case group.FieldSoraImagePrice360: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) - } else if value.Valid { - _m.SoraImagePrice360 = new(float64) - *_m.SoraImagePrice360 = value.Float64 - } - case group.FieldSoraImagePrice540: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) - } else if value.Valid { - _m.SoraImagePrice540 = new(float64) - *_m.SoraImagePrice540 = value.Float64 - } - case group.FieldSoraVideoPricePerRequest: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) - } else if value.Valid { - _m.SoraVideoPricePerRequest = new(float64) - *_m.SoraVideoPricePerRequest = value.Float64 - } - case group.FieldSoraVideoPricePerRequestHd: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) - } else if value.Valid { - _m.SoraVideoPricePerRequestHd = new(float64) - *_m.SoraVideoPricePerRequestHd = value.Float64 - } - case group.FieldSoraStorageQuotaBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageQuotaBytes = value.Int64 - } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -425,6 +385,18 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.AllowMessagesDispatch = value.Bool } + case group.FieldRequireOauthOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_oauth_only", values[i]) + } else if value.Valid { + _m.RequireOauthOnly = value.Bool + } + case group.FieldRequirePrivacySet: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field require_privacy_set", values[i]) + } else if value.Valid { + _m.RequirePrivacySet = value.Bool + } case group.FieldDefaultMappedModel: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field default_mapped_model", values[i]) @@ -574,29 +546,6 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") - if v := _m.SoraImagePrice360; v != nil { - builder.WriteString("sora_image_price_360=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraImagePrice540; v != nil { - builder.WriteString("sora_image_price_540=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraVideoPricePerRequest; v != nil { - builder.WriteString("sora_video_price_per_request=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraVideoPricePerRequestHd; v != nil { - builder.WriteString("sora_video_price_per_request_hd=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - builder.WriteString("sora_storage_quota_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) - builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") @@ -628,6 +577,12 @@ func (_m *Group) String() string { builder.WriteString("allow_messages_dispatch=") builder.WriteString(fmt.Sprintf("%v", _m.AllowMessagesDispatch)) builder.WriteString(", ") + builder.WriteString("require_oauth_only=") + builder.WriteString(fmt.Sprintf("%v", _m.RequireOauthOnly)) + builder.WriteString(", ") + builder.WriteString("require_privacy_set=") + builder.WriteString(fmt.Sprintf("%v", _m.RequirePrivacySet)) + builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) builder.WriteByte(')') diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 2612b6cf..21a7c2cb 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,16 +49,6 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" - // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. - FieldSoraImagePrice360 = "sora_image_price_360" - // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. - FieldSoraImagePrice540 = "sora_image_price_540" - // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. - FieldSoraVideoPricePerRequest = "sora_video_price_per_request" - // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. - FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" - // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. - FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -77,6 +67,10 @@ const ( FieldSortOrder = "sort_order" // FieldAllowMessagesDispatch holds the string denoting the allow_messages_dispatch field in the database. FieldAllowMessagesDispatch = "allow_messages_dispatch" + // FieldRequireOauthOnly holds the string denoting the require_oauth_only field in the database. + FieldRequireOauthOnly = "require_oauth_only" + // FieldRequirePrivacySet holds the string denoting the require_privacy_set field in the database. + FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. @@ -171,11 +165,6 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, - FieldSoraImagePrice360, - FieldSoraImagePrice540, - FieldSoraVideoPricePerRequest, - FieldSoraVideoPricePerRequestHd, - FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -185,6 +174,8 @@ var Columns = []string{ FieldSupportedModelScopes, FieldSortOrder, FieldAllowMessagesDispatch, + FieldRequireOauthOnly, + FieldRequirePrivacySet, FieldDefaultMappedModel, } @@ -241,8 +232,6 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int - // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. - DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -255,6 +244,10 @@ var ( DefaultSortOrder int // DefaultAllowMessagesDispatch holds the default value on creation for the "allow_messages_dispatch" field. DefaultAllowMessagesDispatch bool + // DefaultRequireOauthOnly holds the default value on creation for the "require_oauth_only" field. + DefaultRequireOauthOnly bool + // DefaultRequirePrivacySet holds the default value on creation for the "require_privacy_set" field. + DefaultRequirePrivacySet bool // DefaultDefaultMappedModel holds the default value on creation for the "default_mapped_model" field. DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -354,31 +347,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } -// BySoraImagePrice360 orders the results by the sora_image_price_360 field. -func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() -} - -// BySoraImagePrice540 orders the results by the sora_image_price_540 field. -func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() -} - -// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. -func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() -} - -// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. -func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() -} - -// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. -func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() -} - // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() @@ -414,6 +382,16 @@ func ByAllowMessagesDispatch(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldAllowMessagesDispatch, opts...).ToFunc() } +// ByRequireOauthOnly orders the results by the require_oauth_only field. +func ByRequireOauthOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequireOauthOnly, opts...).ToFunc() +} + +// ByRequirePrivacySet orders the results by the require_privacy_set field. +func ByRequirePrivacySet(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequirePrivacySet, opts...).ToFunc() +} + // ByDefaultMappedModel orders the results by the default_mapped_model field. func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 5dd8759e..cba2ce5f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } -// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. -func SoraImagePrice360(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. -func SoraImagePrice540(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) -} - -// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. -func SoraVideoPricePerRequest(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. -func SoraVideoPricePerRequestHd(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. -func SoraStorageQuotaBytes(v int64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -200,6 +175,16 @@ func AllowMessagesDispatch(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnly applies equality check predicate on the "require_oauth_only" field. It's identical to RequireOauthOnlyEQ. +func RequireOauthOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySet applies equality check predicate on the "require_privacy_set" field. It's identical to RequirePrivacySetEQ. +func RequirePrivacySet(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModel applies equality check predicate on the "default_mapped_model" field. It's identical to DefaultMappedModelEQ. func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) @@ -1060,246 +1045,6 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } -// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. -func SoraImagePrice360EQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. -func SoraImagePrice360NEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. -func SoraImagePrice360In(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) -} - -// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. -func SoraImagePrice360NotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) -} - -// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. -func SoraImagePrice360GT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. -func SoraImagePrice360GTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. -func SoraImagePrice360LT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. -func SoraImagePrice360LTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. -func SoraImagePrice360IsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) -} - -// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. -func SoraImagePrice360NotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) -} - -// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. -func SoraImagePrice540EQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. -func SoraImagePrice540NEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. -func SoraImagePrice540In(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) -} - -// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. -func SoraImagePrice540NotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) -} - -// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. -func SoraImagePrice540GT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. -func SoraImagePrice540GTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. -func SoraImagePrice540LT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. -func SoraImagePrice540LTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. -func SoraImagePrice540IsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) -} - -// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. -func SoraImagePrice540NotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) -} - -// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) -} - -// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) -} - -// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestGT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestGTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestLT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestLTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestIsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) -} - -// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) -} - -// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) -} - -// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) -} - -// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdIsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) -} - -// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) -} - -// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesEQ(v int64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGT(v int64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGTE(v int64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLT(v int64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLTE(v int64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) -} - // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1490,6 +1235,26 @@ func AllowMessagesDispatchNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldAllowMessagesDispatch, v)) } +// RequireOauthOnlyEQ applies the EQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequireOauthOnly, v)) +} + +// RequireOauthOnlyNEQ applies the NEQ predicate on the "require_oauth_only" field. +func RequireOauthOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequireOauthOnly, v)) +} + +// RequirePrivacySetEQ applies the EQ predicate on the "require_privacy_set" field. +func RequirePrivacySetEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRequirePrivacySet, v)) +} + +// RequirePrivacySetNEQ applies the NEQ predicate on the "require_privacy_set" field. +func RequirePrivacySetNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRequirePrivacySet, v)) +} + // DefaultMappedModelEQ applies the EQ predicate on the "default_mapped_model" field. func DefaultMappedModelEQ(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 6db5b974..a8c30b18 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,76 +258,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { - _c.mutation.SetSoraImagePrice360(v) - return _c -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraImagePrice360(*v) - } - return _c -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { - _c.mutation.SetSoraImagePrice540(v) - return _c -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraImagePrice540(*v) - } - return _c -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { - _c.mutation.SetSoraVideoPricePerRequest(v) - return _c -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraVideoPricePerRequest(*v) - } - return _c -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { - _c.mutation.SetSoraVideoPricePerRequestHd(v) - return _c -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraVideoPricePerRequestHd(*v) - } - return _c -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { - _c.mutation.SetSoraStorageQuotaBytes(v) - return _c -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { - if v != nil { - _c.SetSoraStorageQuotaBytes(*v) - } - return _c -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -438,6 +368,34 @@ func (_c *GroupCreate) SetNillableAllowMessagesDispatch(v *bool) *GroupCreate { return _c } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_c *GroupCreate) SetRequireOauthOnly(v bool) *GroupCreate { + _c.mutation.SetRequireOauthOnly(v) + return _c +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequireOauthOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetRequireOauthOnly(*v) + } + return _c +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_c *GroupCreate) SetRequirePrivacySet(v bool) *GroupCreate { + _c.mutation.SetRequirePrivacySet(v) + return _c +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRequirePrivacySet(v *bool) *GroupCreate { + if v != nil { + _c.SetRequirePrivacySet(*v) + } + return _c +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_c *GroupCreate) SetDefaultMappedModel(v string) *GroupCreate { _c.mutation.SetDefaultMappedModel(v) @@ -617,10 +575,6 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - v := group.DefaultSoraStorageQuotaBytes - _c.mutation.SetSoraStorageQuotaBytes(v) - } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -645,6 +599,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultAllowMessagesDispatch _c.mutation.SetAllowMessagesDispatch(v) } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + v := group.DefaultRequireOauthOnly + _c.mutation.SetRequireOauthOnly(v) + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + v := group.DefaultRequirePrivacySet + _c.mutation.SetRequirePrivacySet(v) + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) @@ -701,9 +663,6 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} - } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -722,6 +681,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.AllowMessagesDispatch(); !ok { return &ValidationError{Name: "allow_messages_dispatch", err: errors.New(`ent: missing required field "Group.allow_messages_dispatch"`)} } + if _, ok := _c.mutation.RequireOauthOnly(); !ok { + return &ValidationError{Name: "require_oauth_only", err: errors.New(`ent: missing required field "Group.require_oauth_only"`)} + } + if _, ok := _c.mutation.RequirePrivacySet(); !ok { + return &ValidationError{Name: "require_privacy_set", err: errors.New(`ent: missing required field "Group.require_privacy_set"`)} + } if _, ok := _c.mutation.DefaultMappedModel(); !ok { return &ValidationError{Name: "default_mapped_model", err: errors.New(`ent: missing required field "Group.default_mapped_model"`)} } @@ -825,26 +790,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } - if value, ok := _c.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - _node.SoraImagePrice360 = &value - } - if value, ok := _c.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - _node.SoraImagePrice540 = &value - } - if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - _node.SoraVideoPricePerRequest = &value - } - if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - _node.SoraVideoPricePerRequestHd = &value - } - if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - _node.SoraStorageQuotaBytes = value - } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -881,6 +826,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) _node.AllowMessagesDispatch = value } + if value, ok := _c.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + _node.RequireOauthOnly = value + } + if value, ok := _c.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + _node.RequirePrivacySet = value + } if value, ok := _c.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value @@ -1329,120 +1282,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { - u.Set(group.FieldSoraImagePrice360, v) - return u -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { - u.SetExcluded(group.FieldSoraImagePrice360) - return u -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { - u.Add(group.FieldSoraImagePrice360, v) - return u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { - u.SetNull(group.FieldSoraImagePrice360) - return u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { - u.Set(group.FieldSoraImagePrice540, v) - return u -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { - u.SetExcluded(group.FieldSoraImagePrice540) - return u -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { - u.Add(group.FieldSoraImagePrice540, v) - return u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { - u.SetNull(group.FieldSoraImagePrice540) - return u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { - u.Set(group.FieldSoraVideoPricePerRequest, v) - return u -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { - u.SetExcluded(group.FieldSoraVideoPricePerRequest) - return u -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { - u.Add(group.FieldSoraVideoPricePerRequest, v) - return u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { - u.SetNull(group.FieldSoraVideoPricePerRequest) - return u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { - u.Set(group.FieldSoraVideoPricePerRequestHd, v) - return u -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { - u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) - return u -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { - u.Add(group.FieldSoraVideoPricePerRequestHd, v) - return u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { - u.SetNull(group.FieldSoraVideoPricePerRequestHd) - return u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { - u.Set(group.FieldSoraStorageQuotaBytes, v) - return u -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { - u.SetExcluded(group.FieldSoraStorageQuotaBytes) - return u -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { - u.Add(group.FieldSoraStorageQuotaBytes, v) - return u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1587,6 +1426,30 @@ func (u *GroupUpsert) UpdateAllowMessagesDispatch() *GroupUpsert { return u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsert) SetRequireOauthOnly(v bool) *GroupUpsert { + u.Set(group.FieldRequireOauthOnly, v) + return u +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequireOauthOnly() *GroupUpsert { + u.SetExcluded(group.FieldRequireOauthOnly) + return u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsert) SetRequirePrivacySet(v bool) *GroupUpsert { + u.Set(group.FieldRequirePrivacySet, v) + return u +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRequirePrivacySet() *GroupUpsert { + u.SetExcluded(group.FieldRequirePrivacySet) + return u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsert) SetDefaultMappedModel(v string) *GroupUpsert { u.Set(group.FieldDefaultMappedModel, v) @@ -1980,139 +1843,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice360(v) - }) -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice360(v) - }) -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice360() - }) -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice360() - }) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice540(v) - }) -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice540(v) - }) -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice540() - }) -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice540() - }) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequest(v) - }) -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequest(v) - }) -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequest() - }) -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequest() - }) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequestHd(v) - }) -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequestHd(v) - }) -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequestHd() - }) -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequestHd() - }) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2281,6 +2011,34 @@ func (u *GroupUpsertOne) UpdateAllowMessagesDispatch() *GroupUpsertOne { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertOne) SetRequireOauthOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequireOauthOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsertOne) SetRequirePrivacySet(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRequirePrivacySet() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsertOne) SetDefaultMappedModel(v string) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2842,139 +2600,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice360(v) - }) -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice360(v) - }) -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice360() - }) -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice360() - }) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice540(v) - }) -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice540(v) - }) -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice540() - }) -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice540() - }) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequest(v) - }) -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequest(v) - }) -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequest() - }) -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequest() - }) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequestHd(v) - }) -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequestHd(v) - }) -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequestHd() - }) -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequestHd() - }) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { @@ -3143,6 +2768,34 @@ func (u *GroupUpsertBulk) UpdateAllowMessagesDispatch() *GroupUpsertBulk { }) } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (u *GroupUpsertBulk) SetRequireOauthOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequireOauthOnly(v) + }) +} + +// UpdateRequireOauthOnly sets the "require_oauth_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequireOauthOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequireOauthOnly() + }) +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (u *GroupUpsertBulk) SetRequirePrivacySet(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRequirePrivacySet(v) + }) +} + +// UpdateRequirePrivacySet sets the "require_privacy_set" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRequirePrivacySet() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRequirePrivacySet() + }) +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (u *GroupUpsertBulk) SetDefaultMappedModel(v string) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index b3698596..aa1a83d4 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -355,135 +355,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { - _u.mutation.ResetSoraImagePrice360() - _u.mutation.SetSoraImagePrice360(v) - return _u -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraImagePrice360(*v) - } - return _u -} - -// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. -func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { - _u.mutation.AddSoraImagePrice360(v) - return _u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { - _u.mutation.ClearSoraImagePrice360() - return _u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { - _u.mutation.ResetSoraImagePrice540() - _u.mutation.SetSoraImagePrice540(v) - return _u -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraImagePrice540(*v) - } - return _u -} - -// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. -func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { - _u.mutation.AddSoraImagePrice540(v) - return _u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { - _u.mutation.ClearSoraImagePrice540() - return _u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { - _u.mutation.ResetSoraVideoPricePerRequest() - _u.mutation.SetSoraVideoPricePerRequest(v) - return _u -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraVideoPricePerRequest(*v) - } - return _u -} - -// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. -func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { - _u.mutation.AddSoraVideoPricePerRequest(v) - return _u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { - _u.mutation.ClearSoraVideoPricePerRequest() - return _u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { - _u.mutation.ResetSoraVideoPricePerRequestHd() - _u.mutation.SetSoraVideoPricePerRequestHd(v) - return _u -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraVideoPricePerRequestHd(*v) - } - return _u -} - -// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { - _u.mutation.AddSoraVideoPricePerRequestHd(v) - return _u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { - _u.mutation.ClearSoraVideoPricePerRequestHd() - return _u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -639,6 +510,34 @@ func (_u *GroupUpdate) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate { return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdate) SetRequireOauthOnly(v bool) *GroupUpdate { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequireOauthOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdate) SetRequirePrivacySet(v bool) *GroupUpdate { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRequirePrivacySet(v *bool) *GroupUpdate { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_u *GroupUpdate) SetDefaultMappedModel(v string) *GroupUpdate { _u.mutation.SetDefaultMappedModel(v) @@ -1054,48 +953,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } - if value, ok := _u.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { - _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice360Cleared() { - _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { - _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice540Cleared() { - _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestHdCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1146,6 +1003,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } @@ -1783,135 +1646,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraImagePrice360() - _u.mutation.SetSoraImagePrice360(v) - return _u -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraImagePrice360(*v) - } - return _u -} - -// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. -func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { - _u.mutation.AddSoraImagePrice360(v) - return _u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { - _u.mutation.ClearSoraImagePrice360() - return _u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraImagePrice540() - _u.mutation.SetSoraImagePrice540(v) - return _u -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraImagePrice540(*v) - } - return _u -} - -// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. -func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { - _u.mutation.AddSoraImagePrice540(v) - return _u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { - _u.mutation.ClearSoraImagePrice540() - return _u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraVideoPricePerRequest() - _u.mutation.SetSoraVideoPricePerRequest(v) - return _u -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraVideoPricePerRequest(*v) - } - return _u -} - -// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { - _u.mutation.AddSoraVideoPricePerRequest(v) - return _u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { - _u.mutation.ClearSoraVideoPricePerRequest() - return _u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraVideoPricePerRequestHd() - _u.mutation.SetSoraVideoPricePerRequestHd(v) - return _u -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraVideoPricePerRequestHd(*v) - } - return _u -} - -// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { - _u.mutation.AddSoraVideoPricePerRequestHd(v) - return _u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { - _u.mutation.ClearSoraVideoPricePerRequestHd() - return _u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2067,6 +1801,34 @@ func (_u *GroupUpdateOne) SetNillableAllowMessagesDispatch(v *bool) *GroupUpdate return _u } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (_u *GroupUpdateOne) SetRequireOauthOnly(v bool) *GroupUpdateOne { + _u.mutation.SetRequireOauthOnly(v) + return _u +} + +// SetNillableRequireOauthOnly sets the "require_oauth_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequireOauthOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequireOauthOnly(*v) + } + return _u +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (_u *GroupUpdateOne) SetRequirePrivacySet(v bool) *GroupUpdateOne { + _u.mutation.SetRequirePrivacySet(v) + return _u +} + +// SetNillableRequirePrivacySet sets the "require_privacy_set" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRequirePrivacySet(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetRequirePrivacySet(*v) + } + return _u +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (_u *GroupUpdateOne) SetDefaultMappedModel(v string) *GroupUpdateOne { _u.mutation.SetDefaultMappedModel(v) @@ -2512,48 +2274,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } - if value, ok := _u.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { - _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice360Cleared() { - _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { - _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice540Cleared() { - _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestHdCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -2604,6 +2324,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AllowMessagesDispatch(); ok { _spec.SetField(group.FieldAllowMessagesDispatch, field.TypeBool, value) } + if value, ok := _u.mutation.RequireOauthOnly(); ok { + _spec.SetField(group.FieldRequireOauthOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.RequirePrivacySet(); ok { + _spec.SetField(group.FieldRequirePrivacySet, field.TypeBool, value) + } if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index c472d7e0..5400bf93 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -395,11 +395,6 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -409,6 +404,8 @@ var ( {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false}, + {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, + {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, } // GroupsTable holds the schema information for the "groups" table. @@ -445,7 +442,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[30]}, + Columns: []*schema.Column{GroupsColumns[25]}, }, }, } @@ -742,6 +739,10 @@ var ( {Name: "model", Type: field.TypeString, Size: 100}, {Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, + {Name: "channel_id", Type: field.TypeInt64, Nullable: true}, + {Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500}, + {Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50}, + {Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -764,7 +765,6 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, - {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, @@ -781,31 +781,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -814,32 +814,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_model", @@ -859,17 +859,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]}, }, }, } @@ -890,8 +890,6 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, - {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, - {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 42c63c2e..d206039a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8230,16 +8230,6 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 - sora_image_price_360 *float64 - addsora_image_price_360 *float64 - sora_image_price_540 *float64 - addsora_image_price_540 *float64 - sora_video_price_per_request *float64 - addsora_video_price_per_request *float64 - sora_video_price_per_request_hd *float64 - addsora_video_price_per_request_hd *float64 - sora_storage_quota_bytes *int64 - addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -8253,6 +8243,8 @@ type GroupMutation struct { sort_order *int addsort_order *int allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool default_mapped_model *string clearedFields map[string]struct{} api_keys map[int64]struct{} @@ -9258,342 +9250,6 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (m *GroupMutation) SetSoraImagePrice360(f float64) { - m.sora_image_price_360 = &f - m.addsora_image_price_360 = nil -} - -// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. -func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { - v := m.sora_image_price_360 - if v == nil { - return - } - return *v, true -} - -// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) - } - return oldValue.SoraImagePrice360, nil -} - -// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. -func (m *GroupMutation) AddSoraImagePrice360(f float64) { - if m.addsora_image_price_360 != nil { - *m.addsora_image_price_360 += f - } else { - m.addsora_image_price_360 = &f - } -} - -// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. -func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { - v := m.addsora_image_price_360 - if v == nil { - return - } - return *v, true -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (m *GroupMutation) ClearSoraImagePrice360() { - m.sora_image_price_360 = nil - m.addsora_image_price_360 = nil - m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} -} - -// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. -func (m *GroupMutation) SoraImagePrice360Cleared() bool { - _, ok := m.clearedFields[group.FieldSoraImagePrice360] - return ok -} - -// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. -func (m *GroupMutation) ResetSoraImagePrice360() { - m.sora_image_price_360 = nil - m.addsora_image_price_360 = nil - delete(m.clearedFields, group.FieldSoraImagePrice360) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (m *GroupMutation) SetSoraImagePrice540(f float64) { - m.sora_image_price_540 = &f - m.addsora_image_price_540 = nil -} - -// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. -func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { - v := m.sora_image_price_540 - if v == nil { - return - } - return *v, true -} - -// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) - } - return oldValue.SoraImagePrice540, nil -} - -// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. -func (m *GroupMutation) AddSoraImagePrice540(f float64) { - if m.addsora_image_price_540 != nil { - *m.addsora_image_price_540 += f - } else { - m.addsora_image_price_540 = &f - } -} - -// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. -func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { - v := m.addsora_image_price_540 - if v == nil { - return - } - return *v, true -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (m *GroupMutation) ClearSoraImagePrice540() { - m.sora_image_price_540 = nil - m.addsora_image_price_540 = nil - m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} -} - -// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. -func (m *GroupMutation) SoraImagePrice540Cleared() bool { - _, ok := m.clearedFields[group.FieldSoraImagePrice540] - return ok -} - -// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. -func (m *GroupMutation) ResetSoraImagePrice540() { - m.sora_image_price_540 = nil - m.addsora_image_price_540 = nil - delete(m.clearedFields, group.FieldSoraImagePrice540) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { - m.sora_video_price_per_request = &f - m.addsora_video_price_per_request = nil -} - -// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. -func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { - v := m.sora_video_price_per_request - if v == nil { - return - } - return *v, true -} - -// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) - } - return oldValue.SoraVideoPricePerRequest, nil -} - -// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. -func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { - if m.addsora_video_price_per_request != nil { - *m.addsora_video_price_per_request += f - } else { - m.addsora_video_price_per_request = &f - } -} - -// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. -func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { - v := m.addsora_video_price_per_request - if v == nil { - return - } - return *v, true -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (m *GroupMutation) ClearSoraVideoPricePerRequest() { - m.sora_video_price_per_request = nil - m.addsora_video_price_per_request = nil - m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} -} - -// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. -func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { - _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] - return ok -} - -// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. -func (m *GroupMutation) ResetSoraVideoPricePerRequest() { - m.sora_video_price_per_request = nil - m.addsora_video_price_per_request = nil - delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { - m.sora_video_price_per_request_hd = &f - m.addsora_video_price_per_request_hd = nil -} - -// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. -func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { - v := m.sora_video_price_per_request_hd - if v == nil { - return - } - return *v, true -} - -// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) - } - return oldValue.SoraVideoPricePerRequestHd, nil -} - -// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { - if m.addsora_video_price_per_request_hd != nil { - *m.addsora_video_price_per_request_hd += f - } else { - m.addsora_video_price_per_request_hd = &f - } -} - -// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. -func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { - v := m.addsora_video_price_per_request_hd - if v == nil { - return - } - return *v, true -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { - m.sora_video_price_per_request_hd = nil - m.addsora_video_price_per_request_hd = nil - m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} -} - -// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. -func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { - _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] - return ok -} - -// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { - m.sora_video_price_per_request_hd = nil - m.addsora_video_price_per_request_hd = nil - delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { - m.sora_storage_quota_bytes = &i - m.addsora_storage_quota_bytes = nil -} - -// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. -func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { - v := m.sora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) - } - return oldValue.SoraStorageQuotaBytes, nil -} - -// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. -func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { - if m.addsora_storage_quota_bytes != nil { - *m.addsora_storage_quota_bytes += i - } else { - m.addsora_storage_quota_bytes = &i - } -} - -// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. -func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { - v := m.addsora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. -func (m *GroupMutation) ResetSoraStorageQuotaBytes() { - m.sora_storage_quota_bytes = nil - m.addsora_storage_quota_bytes = nil -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -10034,6 +9690,78 @@ func (m *GroupMutation) ResetAllowMessagesDispatch() { m.allow_messages_dispatch = nil } +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (m *GroupMutation) SetRequireOauthOnly(b bool) { + m.require_oauth_only = &b +} + +// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. +func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { + v := m.require_oauth_only + if v == nil { + return + } + return *v, true +} + +// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) + } + return oldValue.RequireOauthOnly, nil +} + +// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. +func (m *GroupMutation) ResetRequireOauthOnly() { + m.require_oauth_only = nil +} + +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (m *GroupMutation) SetRequirePrivacySet(b bool) { + m.require_privacy_set = &b +} + +// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. +func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { + v := m.require_privacy_set + if v == nil { + return + } + return *v, true +} + +// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + } + return oldValue.RequirePrivacySet, nil +} + +// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. +func (m *GroupMutation) ResetRequirePrivacySet() { + m.require_privacy_set = nil +} + // SetDefaultMappedModel sets the "default_mapped_model" field. func (m *GroupMutation) SetDefaultMappedModel(s string) { m.default_mapped_model = &s @@ -10428,7 +10156,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 29) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10480,21 +10208,6 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } - if m.sora_image_price_360 != nil { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.sora_image_price_540 != nil { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.sora_video_price_per_request != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.sora_video_price_per_request_hd != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } - if m.sora_storage_quota_bytes != nil { - fields = append(fields, group.FieldSoraStorageQuotaBytes) - } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -10522,6 +10235,12 @@ func (m *GroupMutation) Fields() []string { if m.allow_messages_dispatch != nil { fields = append(fields, group.FieldAllowMessagesDispatch) } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } @@ -10567,16 +10286,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() - case group.FieldSoraImagePrice360: - return m.SoraImagePrice360() - case group.FieldSoraImagePrice540: - return m.SoraImagePrice540() - case group.FieldSoraVideoPricePerRequest: - return m.SoraVideoPricePerRequest() - case group.FieldSoraVideoPricePerRequestHd: - return m.SoraVideoPricePerRequestHd() - case group.FieldSoraStorageQuotaBytes: - return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -10595,6 +10304,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SortOrder() case group.FieldAllowMessagesDispatch: return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() } @@ -10640,16 +10353,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) - case group.FieldSoraImagePrice360: - return m.OldSoraImagePrice360(ctx) - case group.FieldSoraImagePrice540: - return m.OldSoraImagePrice540(ctx) - case group.FieldSoraVideoPricePerRequest: - return m.OldSoraVideoPricePerRequest(ctx) - case group.FieldSoraVideoPricePerRequestHd: - return m.OldSoraVideoPricePerRequestHd(ctx) - case group.FieldSoraStorageQuotaBytes: - return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -10668,6 +10371,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSortOrder(ctx) case group.FieldAllowMessagesDispatch: return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) } @@ -10798,41 +10505,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil - case group.FieldSoraImagePrice360: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraImagePrice360(v) - return nil - case group.FieldSoraImagePrice540: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraImagePrice540(v) - return nil - case group.FieldSoraVideoPricePerRequest: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraVideoPricePerRequest(v) - return nil - case group.FieldSoraVideoPricePerRequestHd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraVideoPricePerRequestHd(v) - return nil - case group.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageQuotaBytes(v) - return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -10896,6 +10568,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetAllowMessagesDispatch(v) return nil + case group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil case group.FieldDefaultMappedModel: v, ok := value.(string) if !ok { @@ -10935,21 +10621,6 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } - if m.addsora_image_price_360 != nil { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.addsora_image_price_540 != nil { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.addsora_video_price_per_request != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.addsora_video_price_per_request_hd != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } - if m.addsora_storage_quota_bytes != nil { - fields = append(fields, group.FieldSoraStorageQuotaBytes) - } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -10983,16 +10654,6 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() - case group.FieldSoraImagePrice360: - return m.AddedSoraImagePrice360() - case group.FieldSoraImagePrice540: - return m.AddedSoraImagePrice540() - case group.FieldSoraVideoPricePerRequest: - return m.AddedSoraVideoPricePerRequest() - case group.FieldSoraVideoPricePerRequestHd: - return m.AddedSoraVideoPricePerRequestHd() - case group.FieldSoraStorageQuotaBytes: - return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -11064,41 +10725,6 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil - case group.FieldSoraImagePrice360: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraImagePrice360(v) - return nil - case group.FieldSoraImagePrice540: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraImagePrice540(v) - return nil - case group.FieldSoraVideoPricePerRequest: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraVideoPricePerRequest(v) - return nil - case group.FieldSoraVideoPricePerRequestHd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraVideoPricePerRequestHd(v) - return nil - case group.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageQuotaBytes(v) - return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -11152,18 +10778,6 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } - if m.FieldCleared(group.FieldSoraImagePrice360) { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.FieldCleared(group.FieldSoraImagePrice540) { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -11211,18 +10825,6 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil - case group.FieldSoraImagePrice360: - m.ClearSoraImagePrice360() - return nil - case group.FieldSoraImagePrice540: - m.ClearSoraImagePrice540() - return nil - case group.FieldSoraVideoPricePerRequest: - m.ClearSoraVideoPricePerRequest() - return nil - case group.FieldSoraVideoPricePerRequestHd: - m.ClearSoraVideoPricePerRequestHd() - return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -11291,21 +10893,6 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil - case group.FieldSoraImagePrice360: - m.ResetSoraImagePrice360() - return nil - case group.FieldSoraImagePrice540: - m.ResetSoraImagePrice540() - return nil - case group.FieldSoraVideoPricePerRequest: - m.ResetSoraVideoPricePerRequest() - return nil - case group.FieldSoraVideoPricePerRequestHd: - m.ResetSoraVideoPricePerRequestHd() - return nil - case group.FieldSoraStorageQuotaBytes: - m.ResetSoraStorageQuotaBytes() - return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -11333,6 +10920,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldAllowMessagesDispatch: m.ResetAllowMessagesDispatch() return nil + case group.FieldRequireOauthOnly: + m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil @@ -19617,6 +19210,11 @@ type UsageLogMutation struct { model *string requested_model *string upstream_model *string + channel_id *int64 + addchannel_id *int64 + model_mapping_chain *string + billing_tier *string + billing_mode *string input_tokens *int addinput_tokens *int output_tokens *int @@ -19657,7 +19255,6 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string - media_type *string cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} @@ -20052,6 +19649,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() { delete(m.clearedFields, usagelog.FieldUpstreamModel) } +// SetChannelID sets the "channel_id" field. +func (m *UsageLogMutation) SetChannelID(i int64) { + m.channel_id = &i + m.addchannel_id = nil +} + +// ChannelID returns the value of the "channel_id" field in the mutation. +func (m *UsageLogMutation) ChannelID() (r int64, exists bool) { + v := m.channel_id + if v == nil { + return + } + return *v, true +} + +// OldChannelID returns the old "channel_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelID: %w", err) + } + return oldValue.ChannelID, nil +} + +// AddChannelID adds i to the "channel_id" field. +func (m *UsageLogMutation) AddChannelID(i int64) { + if m.addchannel_id != nil { + *m.addchannel_id += i + } else { + m.addchannel_id = &i + } +} + +// AddedChannelID returns the value that was added to the "channel_id" field in this mutation. +func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) { + v := m.addchannel_id + if v == nil { + return + } + return *v, true +} + +// ClearChannelID clears the value of the "channel_id" field. +func (m *UsageLogMutation) ClearChannelID() { + m.channel_id = nil + m.addchannel_id = nil + m.clearedFields[usagelog.FieldChannelID] = struct{}{} +} + +// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation. +func (m *UsageLogMutation) ChannelIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldChannelID] + return ok +} + +// ResetChannelID resets all changes to the "channel_id" field. +func (m *UsageLogMutation) ResetChannelID() { + m.channel_id = nil + m.addchannel_id = nil + delete(m.clearedFields, usagelog.FieldChannelID) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (m *UsageLogMutation) SetModelMappingChain(s string) { + m.model_mapping_chain = &s +} + +// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation. +func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) { + v := m.model_mapping_chain + if v == nil { + return + } + return *v, true +} + +// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelMappingChain requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err) + } + return oldValue.ModelMappingChain, nil +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (m *UsageLogMutation) ClearModelMappingChain() { + m.model_mapping_chain = nil + m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{} +} + +// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation. +func (m *UsageLogMutation) ModelMappingChainCleared() bool { + _, ok := m.clearedFields[usagelog.FieldModelMappingChain] + return ok +} + +// ResetModelMappingChain resets all changes to the "model_mapping_chain" field. +func (m *UsageLogMutation) ResetModelMappingChain() { + m.model_mapping_chain = nil + delete(m.clearedFields, usagelog.FieldModelMappingChain) +} + +// SetBillingTier sets the "billing_tier" field. +func (m *UsageLogMutation) SetBillingTier(s string) { + m.billing_tier = &s +} + +// BillingTier returns the value of the "billing_tier" field in the mutation. +func (m *UsageLogMutation) BillingTier() (r string, exists bool) { + v := m.billing_tier + if v == nil { + return + } + return *v, true +} + +// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingTier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingTier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingTier: %w", err) + } + return oldValue.BillingTier, nil +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (m *UsageLogMutation) ClearBillingTier() { + m.billing_tier = nil + m.clearedFields[usagelog.FieldBillingTier] = struct{}{} +} + +// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation. +func (m *UsageLogMutation) BillingTierCleared() bool { + _, ok := m.clearedFields[usagelog.FieldBillingTier] + return ok +} + +// ResetBillingTier resets all changes to the "billing_tier" field. +func (m *UsageLogMutation) ResetBillingTier() { + m.billing_tier = nil + delete(m.clearedFields, usagelog.FieldBillingTier) +} + +// SetBillingMode sets the "billing_mode" field. +func (m *UsageLogMutation) SetBillingMode(s string) { + m.billing_mode = &s +} + +// BillingMode returns the value of the "billing_mode" field in the mutation. +func (m *UsageLogMutation) BillingMode() (r string, exists bool) { + v := m.billing_mode + if v == nil { + return + } + return *v, true +} + +// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingMode: %w", err) + } + return oldValue.BillingMode, nil +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (m *UsageLogMutation) ClearBillingMode() { + m.billing_mode = nil + m.clearedFields[usagelog.FieldBillingMode] = struct{}{} +} + +// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation. +func (m *UsageLogMutation) BillingModeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldBillingMode] + return ok +} + +// ResetBillingMode resets all changes to the "billing_mode" field. +func (m *UsageLogMutation) ResetBillingMode() { + m.billing_mode = nil + delete(m.clearedFields, usagelog.FieldBillingMode) +} + // SetGroupID sets the "group_id" field. func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -21383,55 +21197,6 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } -// SetMediaType sets the "media_type" field. -func (m *UsageLogMutation) SetMediaType(s string) { - m.media_type = &s -} - -// MediaType returns the value of the "media_type" field in the mutation. -func (m *UsageLogMutation) MediaType() (r string, exists bool) { - v := m.media_type - if v == nil { - return - } - return *v, true -} - -// OldMediaType returns the old "media_type" field's value of the UsageLog entity. -// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMediaType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMediaType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMediaType: %w", err) - } - return oldValue.MediaType, nil -} - -// ClearMediaType clears the value of the "media_type" field. -func (m *UsageLogMutation) ClearMediaType() { - m.media_type = nil - m.clearedFields[usagelog.FieldMediaType] = struct{}{} -} - -// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. -func (m *UsageLogMutation) MediaTypeCleared() bool { - _, ok := m.clearedFields[usagelog.FieldMediaType] - return ok -} - -// ResetMediaType resets all changes to the "media_type" field. -func (m *UsageLogMutation) ResetMediaType() { - m.media_type = nil - delete(m.clearedFields, usagelog.FieldMediaType) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { m.cache_ttl_overridden = &b @@ -21673,7 +21438,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 34) + fields := make([]string, 0, 37) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -21695,6 +21460,18 @@ func (m *UsageLogMutation) Fields() []string { if m.upstream_model != nil { fields = append(fields, usagelog.FieldUpstreamModel) } + if m.channel_id != nil { + fields = append(fields, usagelog.FieldChannelID) + } + if m.model_mapping_chain != nil { + fields = append(fields, usagelog.FieldModelMappingChain) + } + if m.billing_tier != nil { + fields = append(fields, usagelog.FieldBillingTier) + } + if m.billing_mode != nil { + fields = append(fields, usagelog.FieldBillingMode) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -21767,9 +21544,6 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } - if m.media_type != nil { - fields = append(fields, usagelog.FieldMediaType) - } if m.cache_ttl_overridden != nil { fields = append(fields, usagelog.FieldCacheTTLOverridden) } @@ -21798,6 +21572,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestedModel() case usagelog.FieldUpstreamModel: return m.UpstreamModel() + case usagelog.FieldChannelID: + return m.ChannelID() + case usagelog.FieldModelMappingChain: + return m.ModelMappingChain() + case usagelog.FieldBillingTier: + return m.BillingTier() + case usagelog.FieldBillingMode: + return m.BillingMode() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -21846,8 +21628,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() - case usagelog.FieldMediaType: - return m.MediaType() case usagelog.FieldCacheTTLOverridden: return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: @@ -21875,6 +21655,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestedModel(ctx) case usagelog.FieldUpstreamModel: return m.OldUpstreamModel(ctx) + case usagelog.FieldChannelID: + return m.OldChannelID(ctx) + case usagelog.FieldModelMappingChain: + return m.OldModelMappingChain(ctx) + case usagelog.FieldBillingTier: + return m.OldBillingTier(ctx) + case usagelog.FieldBillingMode: + return m.OldBillingMode(ctx) case usagelog.FieldGroupID: return m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -21923,8 +21711,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) - case usagelog.FieldMediaType: - return m.OldMediaType(ctx) case usagelog.FieldCacheTTLOverridden: return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: @@ -21987,6 +21773,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetUpstreamModel(v) return nil + case usagelog.FieldChannelID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelID(v) + return nil + case usagelog.FieldModelMappingChain: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelMappingChain(v) + return nil + case usagelog.FieldBillingTier: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingTier(v) + return nil + case usagelog.FieldBillingMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingMode(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -22155,13 +21969,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil - case usagelog.FieldMediaType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMediaType(v) - return nil case usagelog.FieldCacheTTLOverridden: v, ok := value.(bool) if !ok { @@ -22184,6 +21991,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *UsageLogMutation) AddedFields() []string { var fields []string + if m.addchannel_id != nil { + fields = append(fields, usagelog.FieldChannelID) + } if m.addinput_tokens != nil { fields = append(fields, usagelog.FieldInputTokens) } @@ -22246,6 +22056,8 @@ func (m *UsageLogMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { switch name { + case usagelog.FieldChannelID: + return m.AddedChannelID() case usagelog.FieldInputTokens: return m.AddedInputTokens() case usagelog.FieldOutputTokens: @@ -22291,6 +22103,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *UsageLogMutation) AddField(name string, value ent.Value) error { switch name { + case usagelog.FieldChannelID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddChannelID(v) + return nil case usagelog.FieldInputTokens: v, ok := value.(int) if !ok { @@ -22431,6 +22250,18 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldUpstreamModel) { fields = append(fields, usagelog.FieldUpstreamModel) } + if m.FieldCleared(usagelog.FieldChannelID) { + fields = append(fields, usagelog.FieldChannelID) + } + if m.FieldCleared(usagelog.FieldModelMappingChain) { + fields = append(fields, usagelog.FieldModelMappingChain) + } + if m.FieldCleared(usagelog.FieldBillingTier) { + fields = append(fields, usagelog.FieldBillingTier) + } + if m.FieldCleared(usagelog.FieldBillingMode) { + fields = append(fields, usagelog.FieldBillingMode) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -22455,9 +22286,6 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } - if m.FieldCleared(usagelog.FieldMediaType) { - fields = append(fields, usagelog.FieldMediaType) - } return fields } @@ -22478,6 +22306,18 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldUpstreamModel: m.ClearUpstreamModel() return nil + case usagelog.FieldChannelID: + m.ClearChannelID() + return nil + case usagelog.FieldModelMappingChain: + m.ClearModelMappingChain() + return nil + case usagelog.FieldBillingTier: + m.ClearBillingTier() + return nil + case usagelog.FieldBillingMode: + m.ClearBillingMode() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -22502,9 +22342,6 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil - case usagelog.FieldMediaType: - m.ClearMediaType() - return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -22534,6 +22371,18 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldUpstreamModel: m.ResetUpstreamModel() return nil + case usagelog.FieldChannelID: + m.ResetChannelID() + return nil + case usagelog.FieldModelMappingChain: + m.ResetModelMappingChain() + return nil + case usagelog.FieldBillingTier: + m.ResetBillingTier() + return nil + case usagelog.FieldBillingMode: + m.ResetBillingMode() + return nil case usagelog.FieldGroupID: m.ResetGroupID() return nil @@ -22606,9 +22455,6 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil - case usagelog.FieldMediaType: - m.ResetMediaType() - return nil case usagelog.FieldCacheTTLOverridden: m.ResetCacheTTLOverridden() return nil @@ -22787,10 +22633,6 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time - sora_storage_quota_bytes *int64 - addsora_storage_quota_bytes *int64 - sora_storage_used_bytes *int64 - addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -23505,118 +23347,6 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { - m.sora_storage_quota_bytes = &i - m.addsora_storage_quota_bytes = nil -} - -// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. -func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { - v := m.sora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. -// If the User object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) - } - return oldValue.SoraStorageQuotaBytes, nil -} - -// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. -func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { - if m.addsora_storage_quota_bytes != nil { - *m.addsora_storage_quota_bytes += i - } else { - m.addsora_storage_quota_bytes = &i - } -} - -// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. -func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { - v := m.addsora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. -func (m *UserMutation) ResetSoraStorageQuotaBytes() { - m.sora_storage_quota_bytes = nil - m.addsora_storage_quota_bytes = nil -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { - m.sora_storage_used_bytes = &i - m.addsora_storage_used_bytes = nil -} - -// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. -func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { - v := m.sora_storage_used_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. -// If the User object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) - } - return oldValue.SoraStorageUsedBytes, nil -} - -// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. -func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { - if m.addsora_storage_used_bytes != nil { - *m.addsora_storage_used_bytes += i - } else { - m.addsora_storage_used_bytes = &i - } -} - -// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. -func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { - v := m.addsora_storage_used_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. -func (m *UserMutation) ResetSoraStorageUsedBytes() { - m.sora_storage_used_bytes = nil - m.addsora_storage_used_bytes = nil -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -24137,7 +23867,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 16) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -24180,12 +23910,6 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } - if m.sora_storage_quota_bytes != nil { - fields = append(fields, user.FieldSoraStorageQuotaBytes) - } - if m.sora_storage_used_bytes != nil { - fields = append(fields, user.FieldSoraStorageUsedBytes) - } return fields } @@ -24222,10 +23946,6 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() - case user.FieldSoraStorageQuotaBytes: - return m.SoraStorageQuotaBytes() - case user.FieldSoraStorageUsedBytes: - return m.SoraStorageUsedBytes() } return nil, false } @@ -24263,10 +23983,6 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) - case user.FieldSoraStorageQuotaBytes: - return m.OldSoraStorageQuotaBytes(ctx) - case user.FieldSoraStorageUsedBytes: - return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -24374,20 +24090,6 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil - case user.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageQuotaBytes(v) - return nil - case user.FieldSoraStorageUsedBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageUsedBytes(v) - return nil } return fmt.Errorf("unknown User field %s", name) } @@ -24402,12 +24104,6 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } - if m.addsora_storage_quota_bytes != nil { - fields = append(fields, user.FieldSoraStorageQuotaBytes) - } - if m.addsora_storage_used_bytes != nil { - fields = append(fields, user.FieldSoraStorageUsedBytes) - } return fields } @@ -24420,10 +24116,6 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() - case user.FieldSoraStorageQuotaBytes: - return m.AddedSoraStorageQuotaBytes() - case user.FieldSoraStorageUsedBytes: - return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -24447,20 +24139,6 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil - case user.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageQuotaBytes(v) - return nil - case user.FieldSoraStorageUsedBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageUsedBytes(v) - return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -24551,12 +24229,6 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil - case user.FieldSoraStorageQuotaBytes: - m.ResetSoraStorageQuotaBytes() - return nil - case user.FieldSoraStorageUsedBytes: - m.ResetSoraStorageUsedBytes() - return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ca95f13f..803b7bc2 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -430,36 +430,40 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) - // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. - groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() - // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. - group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[19].Descriptor() + groupDescClaudeCodeOnly := groupFields[14].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[23].Descriptor() + groupDescModelRoutingEnabled := groupFields[18].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[24].Descriptor() + groupDescMcpXMLInject := groupFields[19].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[25].Descriptor() + groupDescSupportedModelScopes := groupFields[20].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[26].Descriptor() + groupDescSortOrder := groupFields[21].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. - groupDescAllowMessagesDispatch := groupFields[27].Descriptor() + groupDescAllowMessagesDispatch := groupFields[22].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) + // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. + groupDescRequireOauthOnly := groupFields[23].Descriptor() + // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. + group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) + // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. + groupDescRequirePrivacySet := groupFields[24].Descriptor() + // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. + group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[28].Descriptor() + groupDescDefaultMappedModel := groupFields[25].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -867,92 +871,100 @@ func init() { usagelogDescUpstreamModel := usagelogFields[6].Descriptor() // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) + // usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field. + usagelogDescModelMappingChain := usagelogFields[8].Descriptor() + // usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save. + usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error) + // usagelogDescBillingTier is the schema descriptor for billing_tier field. + usagelogDescBillingTier := usagelogFields[9].Descriptor() + // usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save. + usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error) + // usagelogDescBillingMode is the schema descriptor for billing_mode field. + usagelogDescBillingMode := usagelogFields[10].Descriptor() + // usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save. + usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[9].Descriptor() + usagelogDescInputTokens := usagelogFields[13].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[10].Descriptor() + usagelogDescOutputTokens := usagelogFields[14].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[12].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[16].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[15].Descriptor() + usagelogDescInputCost := usagelogFields[19].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[16].Descriptor() + usagelogDescOutputCost := usagelogFields[20].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[17].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[21].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[18].Descriptor() + usagelogDescCacheReadCost := usagelogFields[22].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[19].Descriptor() + usagelogDescTotalCost := usagelogFields[23].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[20].Descriptor() + usagelogDescActualCost := usagelogFields[24].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[21].Descriptor() + usagelogDescRateMultiplier := usagelogFields[25].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[23].Descriptor() + usagelogDescBillingType := usagelogFields[27].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[24].Descriptor() + usagelogDescStream := usagelogFields[28].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[27].Descriptor() + usagelogDescUserAgent := usagelogFields[31].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[28].Descriptor() + usagelogDescIPAddress := usagelogFields[32].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[29].Descriptor() + usagelogDescImageCount := usagelogFields[33].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[30].Descriptor() + usagelogDescImageSize := usagelogFields[34].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) - // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[31].Descriptor() - // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. - usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[33].Descriptor() + usagelogDescCreatedAt := usagelogFields[36].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() @@ -1044,14 +1056,6 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) - // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. - userDescSoraStorageQuotaBytes := userFields[11].Descriptor() - // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. - user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) - // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. - userDescSoraStorageUsedBytes := userFields[12].Descriptor() - // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. - user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 0f5a7b14..0eb89c18 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - // Sora 按次计费配置(阶段 1) - field.Float("sora_image_price_360"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_image_price_540"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_video_price_per_request"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_video_price_per_request_hd"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - - // Sora 存储配额 - field.Int64("sora_storage_quota_bytes"). - Default(0), - // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). @@ -153,6 +131,12 @@ func (Group) Fields() []ent.Field { field.Bool("allow_messages_dispatch"). Default(false). Comment("是否允许 /v1/messages 调度到此 OpenAI 分组"), + field.Bool("require_oauth_only"). + Default(false). + Comment("仅允许非 apikey 类型账号关联到此分组"), + field.Bool("require_privacy_set"). + Default(false). + Comment("调度时仅允许 privacy 已成功设置的账号"), field.String("default_mapped_model"). MaxLen(100). Default(""). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 32c39e25..bd3ebfcc 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field { MaxLen(100). Optional(). Nillable(), + field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"), + field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"), + field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"), + field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"), field.Int64("group_id"). Optional(). Nillable(), @@ -130,12 +134,6 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), - // 媒体类型字段(sora 使用) - field.String("media_type"). - MaxLen(16). - Optional(). - Nillable(), - // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) field.Bool("cache_ttl_overridden"). Default(false), diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 0a3b5d9e..d443ef45 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,12 +72,6 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), - - // Sora 存储配额 - field.Int64("sora_storage_quota_bytes"). - Default(0), - field.Int64("sora_storage_used_bytes"). - Default(0), } } diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index fb4ee1c5..a8e0cc6c 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -36,6 +36,14 @@ type UsageLog struct { RequestedModel *string `json:"requested_model,omitempty"` // UpstreamModel holds the value of the "upstream_model" field. UpstreamModel *string `json:"upstream_model,omitempty"` + // 渠道 ID + ChannelID *int64 `json:"channel_id,omitempty"` + // 模型映射链 + ModelMappingChain *string `json:"model_mapping_chain,omitempty"` + // 计费层级标签 + BillingTier *string `json:"billing_tier,omitempty"` + // 计费模式:token/per_request/image + BillingMode *string `json:"billing_mode,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. @@ -84,8 +92,6 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` - // MediaType holds the value of the "media_type" field. - MediaType *string `json:"media_type,omitempty"` // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. @@ -177,9 +183,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) - case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: + case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -248,6 +254,34 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.UpstreamModel = new(string) *_m.UpstreamModel = value.String } + case usagelog.FieldChannelID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field channel_id", values[i]) + } else if value.Valid { + _m.ChannelID = new(int64) + *_m.ChannelID = value.Int64 + } + case usagelog.FieldModelMappingChain: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i]) + } else if value.Valid { + _m.ModelMappingChain = new(string) + *_m.ModelMappingChain = value.String + } + case usagelog.FieldBillingTier: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field billing_tier", values[i]) + } else if value.Valid { + _m.BillingTier = new(string) + *_m.BillingTier = value.String + } + case usagelog.FieldBillingMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field billing_mode", values[i]) + } else if value.Valid { + _m.BillingMode = new(string) + *_m.BillingMode = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -400,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } - case usagelog.FieldMediaType: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field media_type", values[i]) - } else if value.Valid { - _m.MediaType = new(string) - *_m.MediaType = value.String - } case usagelog.FieldCacheTTLOverridden: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) @@ -505,6 +532,26 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ChannelID; v != nil { + builder.WriteString("channel_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ModelMappingChain; v != nil { + builder.WriteString("model_mapping_chain=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.BillingTier; v != nil { + builder.WriteString("billing_tier=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.BillingMode; v != nil { + builder.WriteString("billing_mode=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) @@ -593,11 +640,6 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") - if v := _m.MediaType; v != nil { - builder.WriteString("media_type=") - builder.WriteString(*v) - } - builder.WriteString(", ") builder.WriteString("cache_ttl_overridden=") builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index b534f193..a7438e60 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -28,6 +28,14 @@ const ( FieldRequestedModel = "requested_model" // FieldUpstreamModel holds the string denoting the upstream_model field in the database. FieldUpstreamModel = "upstream_model" + // FieldChannelID holds the string denoting the channel_id field in the database. + FieldChannelID = "channel_id" + // FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database. + FieldModelMappingChain = "model_mapping_chain" + // FieldBillingTier holds the string denoting the billing_tier field in the database. + FieldBillingTier = "billing_tier" + // FieldBillingMode holds the string denoting the billing_mode field in the database. + FieldBillingMode = "billing_mode" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -76,8 +84,6 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" - // FieldMediaType holds the string denoting the media_type field in the database. - FieldMediaType = "media_type" // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. @@ -141,6 +147,10 @@ var Columns = []string{ FieldModel, FieldRequestedModel, FieldUpstreamModel, + FieldChannelID, + FieldModelMappingChain, + FieldBillingTier, + FieldBillingMode, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -165,7 +175,6 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, - FieldMediaType, FieldCacheTTLOverridden, FieldCreatedAt, } @@ -189,6 +198,12 @@ var ( RequestedModelValidator func(string) error // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. UpstreamModelValidator func(string) error + // ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save. + ModelMappingChainValidator func(string) error + // BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save. + BillingTierValidator func(string) error + // BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save. + BillingModeValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -227,8 +242,6 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error - // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. - MediaTypeValidator func(string) error // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. @@ -278,6 +291,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() } +// ByChannelID orders the results by the channel_id field. +func ByChannelID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelID, opts...).ToFunc() +} + +// ByModelMappingChain orders the results by the model_mapping_chain field. +func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc() +} + +// ByBillingTier orders the results by the billing_tier field. +func ByBillingTier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingTier, opts...).ToFunc() +} + +// ByBillingMode orders the results by the billing_mode field. +func ByBillingMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingMode, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() @@ -398,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } -// ByMediaType orders the results by the media_type field. -func ByMediaType(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldMediaType, opts...).ToFunc() -} - // ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index f95bceb7..b8439a03 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) } +// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ. +func ChannelID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v)) +} + +// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ. +func ModelMappingChain(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v)) +} + +// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ. +func BillingTier(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v)) +} + +// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ. +func BillingMode(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -210,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } -// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. -func MediaType(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) -} - // CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. func CacheTTLOverridden(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) @@ -565,6 +580,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) } +// ChannelIDEQ applies the EQ predicate on the "channel_id" field. +func ChannelIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v)) +} + +// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field. +func ChannelIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v)) +} + +// ChannelIDIn applies the In predicate on the "channel_id" field. +func ChannelIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...)) +} + +// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field. +func ChannelIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...)) +} + +// ChannelIDGT applies the GT predicate on the "channel_id" field. +func ChannelIDGT(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldChannelID, v)) +} + +// ChannelIDGTE applies the GTE predicate on the "channel_id" field. +func ChannelIDGTE(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v)) +} + +// ChannelIDLT applies the LT predicate on the "channel_id" field. +func ChannelIDLT(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldChannelID, v)) +} + +// ChannelIDLTE applies the LTE predicate on the "channel_id" field. +func ChannelIDLTE(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v)) +} + +// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field. +func ChannelIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldChannelID)) +} + +// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field. +func ChannelIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldChannelID)) +} + +// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field. +func ModelMappingChainEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v)) +} + +// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field. +func ModelMappingChainNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v)) +} + +// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field. +func ModelMappingChainIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...)) +} + +// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field. +func ModelMappingChainNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...)) +} + +// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field. +func ModelMappingChainGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v)) +} + +// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field. +func ModelMappingChainGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v)) +} + +// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field. +func ModelMappingChainLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v)) +} + +// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field. +func ModelMappingChainLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v)) +} + +// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field. +func ModelMappingChainContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v)) +} + +// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field. +func ModelMappingChainHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v)) +} + +// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field. +func ModelMappingChainHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v)) +} + +// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field. +func ModelMappingChainIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain)) +} + +// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field. +func ModelMappingChainNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain)) +} + +// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field. +func ModelMappingChainEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v)) +} + +// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field. +func ModelMappingChainContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v)) +} + +// BillingTierEQ applies the EQ predicate on the "billing_tier" field. +func BillingTierEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v)) +} + +// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field. +func BillingTierNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v)) +} + +// BillingTierIn applies the In predicate on the "billing_tier" field. +func BillingTierIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...)) +} + +// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field. +func BillingTierNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...)) +} + +// BillingTierGT applies the GT predicate on the "billing_tier" field. +func BillingTierGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v)) +} + +// BillingTierGTE applies the GTE predicate on the "billing_tier" field. +func BillingTierGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v)) +} + +// BillingTierLT applies the LT predicate on the "billing_tier" field. +func BillingTierLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v)) +} + +// BillingTierLTE applies the LTE predicate on the "billing_tier" field. +func BillingTierLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v)) +} + +// BillingTierContains applies the Contains predicate on the "billing_tier" field. +func BillingTierContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v)) +} + +// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field. +func BillingTierHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v)) +} + +// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field. +func BillingTierHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v)) +} + +// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field. +func BillingTierIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier)) +} + +// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field. +func BillingTierNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier)) +} + +// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field. +func BillingTierEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v)) +} + +// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field. +func BillingTierContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v)) +} + +// BillingModeEQ applies the EQ predicate on the "billing_mode" field. +func BillingModeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v)) +} + +// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field. +func BillingModeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v)) +} + +// BillingModeIn applies the In predicate on the "billing_mode" field. +func BillingModeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...)) +} + +// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field. +func BillingModeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...)) +} + +// BillingModeGT applies the GT predicate on the "billing_mode" field. +func BillingModeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v)) +} + +// BillingModeGTE applies the GTE predicate on the "billing_mode" field. +func BillingModeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v)) +} + +// BillingModeLT applies the LT predicate on the "billing_mode" field. +func BillingModeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v)) +} + +// BillingModeLTE applies the LTE predicate on the "billing_mode" field. +func BillingModeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v)) +} + +// BillingModeContains applies the Contains predicate on the "billing_mode" field. +func BillingModeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v)) +} + +// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field. +func BillingModeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v)) +} + +// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field. +func BillingModeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v)) +} + +// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field. +func BillingModeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode)) +} + +// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field. +func BillingModeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode)) +} + +// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field. +func BillingModeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v)) +} + +// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field. +func BillingModeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -1610,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } -// MediaTypeEQ applies the EQ predicate on the "media_type" field. -func MediaTypeEQ(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) -} - -// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. -func MediaTypeNEQ(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) -} - -// MediaTypeIn applies the In predicate on the "media_type" field. -func MediaTypeIn(vs ...string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) -} - -// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. -func MediaTypeNotIn(vs ...string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) -} - -// MediaTypeGT applies the GT predicate on the "media_type" field. -func MediaTypeGT(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) -} - -// MediaTypeGTE applies the GTE predicate on the "media_type" field. -func MediaTypeGTE(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) -} - -// MediaTypeLT applies the LT predicate on the "media_type" field. -func MediaTypeLT(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) -} - -// MediaTypeLTE applies the LTE predicate on the "media_type" field. -func MediaTypeLTE(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) -} - -// MediaTypeContains applies the Contains predicate on the "media_type" field. -func MediaTypeContains(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) -} - -// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. -func MediaTypeHasPrefix(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) -} - -// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. -func MediaTypeHasSuffix(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) -} - -// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. -func MediaTypeIsNil() predicate.UsageLog { - return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) -} - -// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. -func MediaTypeNotNil() predicate.UsageLog { - return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) -} - -// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. -func MediaTypeEqualFold(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) -} - -// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. -func MediaTypeContainsFold(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) -} - // CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index 6ae0bf59..fded364e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { return _c } +// SetChannelID sets the "channel_id" field. +func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate { + _c.mutation.SetChannelID(v) + return _c +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetChannelID(*v) + } + return _c +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate { + _c.mutation.SetModelMappingChain(v) + return _c +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate { + if v != nil { + _c.SetModelMappingChain(*v) + } + return _c +} + +// SetBillingTier sets the "billing_tier" field. +func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate { + _c.mutation.SetBillingTier(v) + return _c +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate { + if v != nil { + _c.SetBillingTier(*v) + } + return _c +} + +// SetBillingMode sets the "billing_mode" field. +func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate { + _c.mutation.SetBillingMode(v) + return _c +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate { + if v != nil { + _c.SetBillingMode(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -421,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } -// SetMediaType sets the "media_type" field. -func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { - _c.mutation.SetMediaType(v) - return _c -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { - if v != nil { - _c.SetMediaType(*v) - } - return _c -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { _c.mutation.SetCacheTTLOverridden(v) @@ -634,6 +676,21 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _c.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _c.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _c.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -697,11 +754,6 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _c.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _, ok := _c.mutation.CacheTTLOverridden(); !ok { return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} } @@ -760,6 +812,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _node.UpstreamModel = &value } + if value, ok := _c.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + _node.ChannelID = &value + } + if value, ok := _c.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + _node.ModelMappingChain = &value + } + if value, ok := _c.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + _node.BillingTier = &value + } + if value, ok := _c.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value) + _node.BillingMode = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -848,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } - if value, ok := _c.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - _node.MediaType = &value - } if value, ok := _c.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _node.CacheTTLOverridden = value @@ -1093,6 +1157,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { return u } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldChannelID, v) + return u +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldChannelID) + return u +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert { + u.Add(usagelog.FieldChannelID, v) + return u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert { + u.SetNull(usagelog.FieldChannelID) + return u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert { + u.Set(usagelog.FieldModelMappingChain, v) + return u +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldModelMappingChain) + return u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert { + u.SetNull(usagelog.FieldModelMappingChain) + return u +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert { + u.Set(usagelog.FieldBillingTier, v) + return u +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingTier) + return u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert { + u.SetNull(usagelog.FieldBillingTier) + return u +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert { + u.Set(usagelog.FieldBillingMode, v) + return u +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingMode) + return u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert { + u.SetNull(usagelog.FieldBillingMode) + return u +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1537,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { - u.Set(usagelog.FieldMediaType, v) - return u -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { - u.SetExcluded(usagelog.FieldMediaType) - return u -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { - u.SetNull(usagelog.FieldMediaType) - return u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { u.Set(usagelog.FieldCacheTTLOverridden, v) @@ -1724,6 +1848,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { }) } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetChannelID(v) + }) +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddChannelID(v) + }) +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateChannelID() + }) +} + +// ClearChannelID clears the value of the "channel_id" field. +func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearChannelID() + }) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetModelMappingChain(v) + }) +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModelMappingChain() + }) +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearModelMappingChain() + }) +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingTier(v) + }) +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingTier() + }) +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingTier() + }) +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingMode(v) + }) +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingMode() + }) +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingMode() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2242,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.SetMediaType(v) - }) -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.UpdateMediaType() - }) -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.ClearMediaType() - }) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2600,6 +2794,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { }) } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetChannelID(v) + }) +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddChannelID(v) + }) +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateChannelID() + }) +} + +// ClearChannelID clears the value of the "channel_id" field. +func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearChannelID() + }) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetModelMappingChain(v) + }) +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModelMappingChain() + }) +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearModelMappingChain() + }) +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingTier(v) + }) +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingTier() + }) +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingTier() + }) +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingMode(v) + }) +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingMode() + }) +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingMode() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { @@ -3118,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.SetMediaType(v) - }) -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.UpdateMediaType() - }) -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.ClearMediaType() - }) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 516407b9..bb5ac86c 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { return _u } +// SetChannelID sets the "channel_id" field. +func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate { + _u.mutation.ResetChannelID() + _u.mutation.SetChannelID(v) + return _u +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetChannelID(*v) + } + return _u +} + +// AddChannelID adds value to the "channel_id" field. +func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate { + _u.mutation.AddChannelID(v) + return _u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate { + _u.mutation.ClearChannelID() + return _u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate { + _u.mutation.SetModelMappingChain(v) + return _u +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate { + if v != nil { + _u.SetModelMappingChain(*v) + } + return _u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate { + _u.mutation.ClearModelMappingChain() + return _u +} + +// SetBillingTier sets the "billing_tier" field. +func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate { + _u.mutation.SetBillingTier(v) + return _u +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate { + if v != nil { + _u.SetBillingTier(*v) + } + return _u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate { + _u.mutation.ClearBillingTier() + return _u +} + +// SetBillingMode sets the "billing_mode" field. +func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate { + _u.mutation.SetBillingMode(v) + return _u +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate { + if v != nil { + _u.SetBillingMode(*v) + } + return _u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate { + _u.mutation.ClearBillingMode() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -652,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } -// SetMediaType sets the "media_type" field. -func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { - _u.mutation.SetMediaType(v) - return _u -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { - if v != nil { - _u.SetMediaType(*v) - } - return _u -} - -// ClearMediaType clears the value of the "media_type" field. -func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { - _u.mutation.ClearMediaType() - return _u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { _u.mutation.SetCacheTTLOverridden(v) @@ -795,6 +862,21 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _u.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _u.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _u.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -810,11 +892,6 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _u.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -857,6 +934,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.UpstreamModelCleared() { _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) } + if value, ok := _u.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedChannelID(); ok { + _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if _u.mutation.ChannelIDCleared() { + _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64) + } + if value, ok := _u.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + } + if _u.mutation.ModelMappingChainCleared() { + _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString) + } + if value, ok := _u.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + } + if _u.mutation.BillingTierCleared() { + _spec.ClearField(usagelog.FieldBillingTier, field.TypeString) + } + if value, ok := _u.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value) + } + if _u.mutation.BillingModeCleared() { + _spec.ClearField(usagelog.FieldBillingMode, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -995,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } - if value, ok := _u.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - } - if _u.mutation.MediaTypeCleared() { - _spec.ClearField(usagelog.FieldMediaType, field.TypeString) - } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } @@ -1279,6 +1377,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { return _u } +// SetChannelID sets the "channel_id" field. +func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne { + _u.mutation.ResetChannelID() + _u.mutation.SetChannelID(v) + return _u +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetChannelID(*v) + } + return _u +} + +// AddChannelID adds value to the "channel_id" field. +func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne { + _u.mutation.AddChannelID(v) + return _u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne { + _u.mutation.ClearChannelID() + return _u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne { + _u.mutation.SetModelMappingChain(v) + return _u +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetModelMappingChain(*v) + } + return _u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne { + _u.mutation.ClearModelMappingChain() + return _u +} + +// SetBillingTier sets the "billing_tier" field. +func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne { + _u.mutation.SetBillingTier(v) + return _u +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingTier(*v) + } + return _u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne { + _u.mutation.ClearBillingTier() + return _u +} + +// SetBillingMode sets the "billing_mode" field. +func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne { + _u.mutation.SetBillingMode(v) + return _u +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingMode(*v) + } + return _u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne { + _u.mutation.ClearBillingMode() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1789,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } -// SetMediaType sets the "media_type" field. -func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { - _u.mutation.SetMediaType(v) - return _u -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { - if v != nil { - _u.SetMediaType(*v) - } - return _u -} - -// ClearMediaType clears the value of the "media_type" field. -func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { - _u.mutation.ClearMediaType() - return _u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { _u.mutation.SetCacheTTLOverridden(v) @@ -1945,6 +2110,21 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _u.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _u.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _u.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -1960,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _u.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -2024,6 +2199,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.UpstreamModelCleared() { _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) } + if value, ok := _u.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedChannelID(); ok { + _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if _u.mutation.ChannelIDCleared() { + _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64) + } + if value, ok := _u.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + } + if _u.mutation.ModelMappingChainCleared() { + _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString) + } + if value, ok := _u.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + } + if _u.mutation.BillingTierCleared() { + _spec.ClearField(usagelog.FieldBillingTier, field.TypeString) + } + if value, ok := _u.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value) + } + if _u.mutation.BillingModeCleared() { + _spec.ClearField(usagelog.FieldBillingMode, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -2162,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } - if value, ok := _u.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - } - if _u.mutation.MediaTypeCleared() { - _spec.ClearField(usagelog.FieldMediaType, field.TypeString) - } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } diff --git a/backend/ent/user.go b/backend/ent/user.go index b3f933f6..2435aa1b 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,10 +45,6 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` - // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` - // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. - SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -181,7 +177,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: + case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -295,18 +291,6 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } - case user.FieldSoraStorageQuotaBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageQuotaBytes = value.Int64 - } - case user.FieldSoraStorageUsedBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageUsedBytes = value.Int64 - } default: _m.selectValues.Set(columns[i], values[i]) } @@ -440,12 +424,6 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } - builder.WriteString(", ") - builder.WriteString("sora_storage_quota_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) - builder.WriteString(", ") - builder.WriteString("sora_storage_used_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 155b9160..ae9418ff 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,10 +43,6 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" - // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. - FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" - // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. - FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -156,8 +152,6 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, - FieldSoraStorageQuotaBytes, - FieldSoraStorageUsedBytes, } var ( @@ -214,10 +208,6 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool - // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. - DefaultSoraStorageQuotaBytes int64 - // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. - DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. @@ -298,16 +288,6 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } -// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. -func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() -} - -// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. -func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() -} - // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index e26afcf3..1de61037 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,16 +125,6 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } -// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. -func SoraStorageQuotaBytes(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. -func SoraStorageUsedBytes(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) -} - // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -870,86 +860,6 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } -// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesEQ(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNEQ(v int64) predicate.User { - return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGT(v int64) predicate.User { - return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGTE(v int64) predicate.User { - return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLT(v int64) predicate.User { - return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLTE(v int64) predicate.User { - return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesEQ(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesNEQ(v int64) predicate.User { - return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) -} - -// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) -} - -// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesGT(v int64) predicate.User { - return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesGTE(v int64) predicate.User { - return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesLT(v int64) predicate.User { - return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesLTE(v int64) predicate.User { - return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) -} - // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index df0c6bcc..f862a580 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,34 +210,6 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { - _c.mutation.SetSoraStorageQuotaBytes(v) - return _c -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { - if v != nil { - _c.SetSoraStorageQuotaBytes(*v) - } - return _c -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { - _c.mutation.SetSoraStorageUsedBytes(v) - return _c -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { - if v != nil { - _c.SetSoraStorageUsedBytes(*v) - } - return _c -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -452,14 +424,6 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - v := user.DefaultSoraStorageQuotaBytes - _c.mutation.SetSoraStorageQuotaBytes(v) - } - if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { - v := user.DefaultSoraStorageUsedBytes - _c.mutation.SetSoraStorageUsedBytes(v) - } return nil } @@ -523,12 +487,6 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} - } - if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { - return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} - } return nil } @@ -612,14 +570,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } - if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - _node.SoraStorageQuotaBytes = value - } - if value, ok := _c.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - _node.SoraStorageUsedBytes = value - } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1006,42 +956,6 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { - u.Set(user.FieldSoraStorageQuotaBytes, v) - return u -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { - u.SetExcluded(user.FieldSoraStorageQuotaBytes) - return u -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { - u.Add(user.FieldSoraStorageQuotaBytes, v) - return u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { - u.Set(user.FieldSoraStorageUsedBytes, v) - return u -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { - u.SetExcluded(user.FieldSoraStorageUsedBytes) - return u -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { - u.Add(user.FieldSoraStorageUsedBytes, v) - return u -} - // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1304,48 +1218,6 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageUsedBytes(v) - }) -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageUsedBytes(v) - }) -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageUsedBytes() - }) -} - // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1774,48 +1646,6 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageUsedBytes(v) - }) -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageUsedBytes(v) - }) -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageUsedBytes() - }) -} - // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index f71f0cad..80222c92 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,48 +242,6 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { - _u.mutation.ResetSoraStorageUsedBytes() - _u.mutation.SetSoraStorageUsedBytes(v) - return _u -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { - if v != nil { - _u.SetSoraStorageUsedBytes(*v) - } - return _u -} - -// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. -func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { - _u.mutation.AddSoraStorageUsedBytes(v) - return _u -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -751,18 +709,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { - _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1406,48 +1352,6 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { - _u.mutation.ResetSoraStorageUsedBytes() - _u.mutation.SetSoraStorageUsedBytes(v) - return _u -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { - if v != nil { - _u.SetSoraStorageUsedBytes(*v) - } - return _u -} - -// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. -func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { - _u.mutation.AddSoraStorageUsedBytes(v) - return _u -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1945,18 +1849,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { - _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 8da26d70..5953be65 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -77,7 +77,6 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - Sora SoraConfig `mapstructure:"sora"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -197,8 +196,6 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` - // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) - SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -303,59 +300,6 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } -// SoraConfig 直连 Sora 配置 -type SoraConfig struct { - Client SoraClientConfig `mapstructure:"client"` - Storage SoraStorageConfig `mapstructure:"storage"` -} - -// SoraClientConfig 直连 Sora 客户端配置 -type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` - CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` -} - -// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 -type SoraCurlCFFISidecarConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - Impersonate string `mapstructure:"impersonate"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` - SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` -} - -// SoraStorageConfig 媒体存储配置 -type SoraStorageConfig struct { - Type string `mapstructure:"type"` - LocalPath string `mapstructure:"local_path"` - FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` - MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` - DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` - MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` - Debug bool `mapstructure:"debug"` - Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` -} - -// SoraStorageCleanupConfig 媒体清理配置 -type SoraStorageCleanupConfig struct { - Enabled bool `mapstructure:"enabled"` - Schedule string `mapstructure:"schedule"` - RetentionDays int `mapstructure:"retention_days"` -} - // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -428,24 +372,6 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` - // Sora 专用配置 - // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) - SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` - // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) - SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` - // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) - SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` - // SoraStreamMode: stream 强制策略(force/error) - SoraStreamMode string `mapstructure:"sora_stream_mode"` - // SoraModelFilters: 模型列表过滤配置 - SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` - // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key - SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` - // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) - SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` - // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) - SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` - // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -665,12 +591,6 @@ type GatewayUsageRecordConfig struct { AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` } -// SoraModelFiltersConfig Sora 模型过滤配置 -type SoraModelFiltersConfig struct { - // HidePromptEnhance 是否隐藏 prompt-enhance 模型 - HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` -} - // NodeTLSProxyConfig Node.js TLS 代理配置 // 通过本地 Node.js 进程转发 HTTPS 请求,利用原生 TLS 栈产生真实 JA3 指纹 type NodeTLSProxyConfig struct { @@ -1471,13 +1391,6 @@ func setDefaults() { viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) - viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) - viper.SetDefault("gateway.sora_request_timeout_seconds", 180) - viper.SetDefault("gateway.sora_stream_mode", "force") - viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) - viper.SetDefault("gateway.sora_media_require_api_key", true) - viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) @@ -1549,45 +1462,12 @@ func setDefaults() { _ = viper.BindEnv("gateway.node_tls_proxy.upstream_host", "GATEWAY_NODE_TLS_PROXY_UPSTREAM_HOST") viper.SetDefault("concurrency.ping_interval", 10) - // Sora 直连配置 - viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") - viper.SetDefault("sora.client.timeout_seconds", 120) - viper.SetDefault("sora.client.max_retries", 3) - viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) - viper.SetDefault("sora.client.poll_interval_seconds", 2) - viper.SetDefault("sora.client.max_poll_attempts", 600) - viper.SetDefault("sora.client.recent_task_limit", 50) - viper.SetDefault("sora.client.recent_task_limit_max", 200) - viper.SetDefault("sora.client.debug", false) - viper.SetDefault("sora.client.use_openai_token_provider", false) - viper.SetDefault("sora.client.headers", map[string]string{}) - viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - viper.SetDefault("sora.client.disable_tls_fingerprint", false) - viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) - viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") - viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") - viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) - viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) - viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) - - viper.SetDefault("sora.storage.type", "local") - viper.SetDefault("sora.storage.local_path", "") - viper.SetDefault("sora.storage.fallback_to_upstream", true) - viper.SetDefault("sora.storage.max_concurrent_downloads", 4) - viper.SetDefault("sora.storage.download_timeout_seconds", 120) - viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) - viper.SetDefault("sora.storage.debug", false) - viper.SetDefault("sora.storage.cleanup.enabled", true) - viper.SetDefault("sora.storage.cleanup.retention_days", 7) - viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") - // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 - viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -1963,86 +1843,6 @@ func (c *Config) Validate() error { if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") } - if c.Gateway.SoraMaxBodySize < 0 { - return fmt.Errorf("gateway.sora_max_body_size must be non-negative") - } - if c.Gateway.SoraStreamTimeoutSeconds < 0 { - return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") - } - if c.Gateway.SoraRequestTimeoutSeconds < 0 { - return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") - } - if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { - return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") - } - if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { - switch mode { - case "force", "error": - default: - return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") - } - } - if c.Sora.Client.TimeoutSeconds < 0 { - return fmt.Errorf("sora.client.timeout_seconds must be non-negative") - } - if c.Sora.Client.MaxRetries < 0 { - return fmt.Errorf("sora.client.max_retries must be non-negative") - } - if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { - return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") - } - if c.Sora.Client.PollIntervalSeconds < 0 { - return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") - } - if c.Sora.Client.MaxPollAttempts < 0 { - return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") - } - if c.Sora.Client.RecentTaskLimit < 0 { - return fmt.Errorf("sora.client.recent_task_limit must be non-negative") - } - if c.Sora.Client.RecentTaskLimitMax < 0 { - return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") - } - if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && - c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { - c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit - } - if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { - return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") - } - if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { - return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") - } - if !c.Sora.Client.CurlCFFISidecar.Enabled { - return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true") - } - if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { - return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") - } - if c.Sora.Storage.MaxConcurrentDownloads < 0 { - return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") - } - if c.Sora.Storage.DownloadTimeoutSeconds < 0 { - return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") - } - if c.Sora.Storage.MaxDownloadBytes < 0 { - return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") - } - if c.Sora.Storage.Cleanup.Enabled { - if c.Sora.Storage.Cleanup.RetentionDays <= 0 { - return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") - } - if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { - return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") - } - } else { - if c.Sora.Storage.Cleanup.RetentionDays < 0 { - return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") - } - } - if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { - return fmt.Errorf("sora.storage.type must be 'local'") - } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index abb76549..2de5451e 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1554,94 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) { } } -func TestSoraCurlCFFISidecarDefaults(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if !cfg.Sora.Client.CurlCFFISidecar.Enabled { - t.Fatalf("Sora curl_cffi sidecar should be enabled by default") - } - if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { - t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") - } - if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { - t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") - } - if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { - t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") - } - if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { - t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") - } - if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { - t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") - } -} - -func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.Enabled = false - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { - t.Fatalf("Validate() error = %v, want sidecar enabled error", err) - } -} - -func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { - t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) - } -} - -func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { - t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) - } -} - -func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { - t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) - } -} - func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 4e69ca02..429486c3 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,7 +22,6 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" - PlatformSora = "sora" ) // Account type constants diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 9d5f9de2..37b654ce 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -567,15 +567,15 @@ func defaultProxyName(name string) string { // enrichCredentialsFromIDToken performs best-effort extraction of user info fields // (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials. -// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently. +// Only applies to OpenAI OAuth accounts. Skips expired token errors silently. // Existing credential values are never overwritten — only missing fields are filled. func enrichCredentialsFromIDToken(item *DataAccount) { if item.Credentials == nil { return } - // Only enrich OpenAI/Sora OAuth accounts + // Only enrich OpenAI OAuth accounts platform := strings.ToLower(strings.TrimSpace(item.Platform)) - if platform != service.PlatformOpenAI && platform != service.PlatformSora { + if platform != service.PlatformOpenAI { return } if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index f860e666..e8750ac8 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -839,6 +839,7 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv if updateErr != nil { return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr) } + h.adminService.EnsureAntigravityPrivacy(ctx, updatedAccount) return updatedAccount, "missing_project_id_temporary", nil } @@ -1886,12 +1887,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } - // Handle Sora accounts - if account.Platform == service.PlatformSora { - response.Success(c, service.DefaultSoraModels(nil)) - return - } - // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 9759cef5..60d68913 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -380,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se {Target: "openai", Status: "pass", HTTPStatus: 401}, {Target: "anthropic", Status: "pass", HTTPStatus: 401}, {Target: "gemini", Status: "pass", HTTPStatus: 200}, - {Target: "sora", Status: "pass", HTTPStatus: 401}, }, }, nil } diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go new file mode 100644 index 00000000..563a27ce --- /dev/null +++ b/backend/internal/handler/admin/channel_handler.go @@ -0,0 +1,452 @@ +package admin + +import ( + "errors" + "fmt" + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ChannelHandler handles admin channel management +type ChannelHandler struct { + channelService *service.ChannelService + billingService *service.BillingService +} + +// NewChannelHandler creates a new admin channel handler +func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler { + return &ChannelHandler{channelService: channelService, billingService: billingService} +} + +// --- Request / Response types --- + +type createChannelRequest struct { + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels bool `json:"restrict_models"` +} + +type updateChannelRequest struct { + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels *bool `json:"restrict_models"` +} + +type channelModelPricingRequest struct { + Platform string `json:"platform" binding:"omitempty,max=50"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` +} + +type pricingIntervalRequest struct { + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int `json:"sort_order"` +} + +type channelResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type channelModelPricingResponse struct { + ID int64 `json:"id"` + Platform string `json:"platform"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []pricingIntervalResponse `json:"intervals"` +} + +type pricingIntervalResponse struct { + ID int64 `json:"id"` + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int `json:"sort_order"` +} + +func channelToResponse(ch *service.Channel) *channelResponse { + if ch == nil { + return nil + } + resp := &channelResponse{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + RestrictModels: ch.RestrictModels, + GroupIDs: ch.GroupIDs, + ModelMapping: ch.ModelMapping, + CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), + } + resp.BillingModelSource = ch.BillingModelSource + if resp.BillingModelSource == "" { + resp.BillingModelSource = service.BillingModelSourceChannelMapped + } + if resp.GroupIDs == nil { + resp.GroupIDs = []int64{} + } + if resp.ModelMapping == nil { + resp.ModelMapping = map[string]map[string]string{} + } + + resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing)) + for _, p := range ch.ModelPricing { + resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) + } + return resp +} + +func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse { + models := p.Models + if models == nil { + models = []string{} + } + billingMode := string(p.BillingMode) + if billingMode == "" { + billingMode = string(service.BillingModeToken) + } + platform := p.Platform + if platform == "" { + platform = service.PlatformAnthropic + } + intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) + for _, iv := range p.Intervals { + intervals = append(intervals, intervalToResponse(iv)) + } + return channelModelPricingResponse{ + ID: p.ID, + Platform: platform, + Models: models, + BillingMode: billingMode, + InputPrice: p.InputPrice, + OutputPrice: p.OutputPrice, + CacheWritePrice: p.CacheWritePrice, + CacheReadPrice: p.CacheReadPrice, + ImageOutputPrice: p.ImageOutputPrice, + PerRequestPrice: p.PerRequestPrice, + Intervals: intervals, + } +} + +func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse { + return pricingIntervalResponse{ + ID: iv.ID, + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + } +} + +func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing { + result := make([]service.ChannelModelPricing, 0, len(reqs)) + for _, r := range reqs { + billingMode := service.BillingMode(r.BillingMode) + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := r.Platform + if platform == "" { + platform = service.PlatformAnthropic + } + intervals := make([]service.PricingInterval, 0, len(r.Intervals)) + for _, iv := range r.Intervals { + intervals = append(intervals, service.PricingInterval{ + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + }) + } + result = append(result, service.ChannelModelPricing{ + Platform: platform, + Models: r.Models, + BillingMode: billingMode, + InputPrice: r.InputPrice, + OutputPrice: r.OutputPrice, + CacheWritePrice: r.CacheWritePrice, + CacheReadPrice: r.CacheReadPrice, + ImageOutputPrice: r.ImageOutputPrice, + PerRequestPrice: r.PerRequestPrice, + Intervals: intervals, + }) + } + return result +} + +// validatePricingBillingMode 校验计费配置 +func validatePricingBillingMode(pricing []service.ChannelModelPricing) error { + for _, p := range pricing { + // 按次/图片模式必须配置默认价格或区间 + if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + return errors.New("per-request price or intervals required for per_request/image billing mode") + } + } + // 校验价格不能为负 + if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil { + return err + } + if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil { + return err + } + if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil { + return err + } + // 校验 interval:至少有一个价格字段非空 + for _, iv := range p.Intervals { + if iv.InputPrice == nil && iv.OutputPrice == nil && + iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && + iv.PerRequestPrice == nil { + return fmt.Errorf("interval [%d, %s] has no price fields set for model %v", + iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models) + } + } + } + return nil +} + +func validatePriceNotNegative(field string, val *float64) error { + if val != nil && *val < 0 { + return fmt.Errorf("%s must be >= 0", field) + } + return nil +} + +func formatMaxTokens(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + +// --- Handlers --- + +// List handles listing channels with pagination +// GET /api/v1/admin/channels +func (h *ChannelHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + if len(search) > 100 { + search = search[:100] + } + + channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]*channelResponse, 0, len(channels)) + for i := range channels { + out = append(out, channelToResponse(&channels[i])) + } + response.Paginated(c, out, pag.Total, page, pageSize) +} + +// GetByID handles getting a channel by ID +// GET /api/v1/admin/channels/:id +func (h *ChannelHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + channel, err := h.channelService.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Create handles creating a new channel +// POST /api/v1/admin/channels +func (h *ChannelHandler) Create(c *gin.Context) { + var req createChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + pricing := pricingRequestToService(req.ModelPricing) + if err := validatePricingBillingMode(pricing); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricing, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Update handles updating a channel +// PUT /api/v1/admin/channels/:id +func (h *ChannelHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + var req updateChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + input := &service.UpdateChannelInput{ + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + } + if req.ModelPricing != nil { + pricing := pricingRequestToService(*req.ModelPricing) + if err := validatePricingBillingMode(pricing); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + input.ModelPricing = &pricing + } + + channel, err := h.channelService.Update(c.Request.Context(), id, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Delete handles deleting a channel +// DELETE /api/v1/admin/channels/:id +func (h *ChannelHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + if err := h.channelService.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Channel deleted successfully"}) +} + +// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充) +// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4 +func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { + model := strings.TrimSpace(c.Query("model")) + if model == "" { + response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required"). + WithMetadata(map[string]string{"param": "model"})) + return + } + + pricing, err := h.billingService.GetModelPricing(model) + if err != nil { + // 模型不在定价列表中 + response.Success(c, gin.H{"found": false}) + return + } + + response.Success(c, gin.H{ + "found": true, + "input_price": pricing.InputPricePerToken, + "output_price": pricing.OutputPricePerToken, + "cache_write_price": pricing.CacheCreationPricePerToken, + "cache_read_price": pricing.CacheReadPricePerToken, + "image_output_price": pricing.ImageOutputPricePerToken, + }) +} diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go new file mode 100644 index 00000000..6f6ea526 --- /dev/null +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -0,0 +1,502 @@ +//go:build unit + +package admin + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func float64Ptr(v float64) *float64 { return &v } +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// 1. channelToResponse +// --------------------------------------------------------------------------- + +func TestChannelToResponse_NilInput(t *testing.T) { + require.Nil(t, channelToResponse(nil)) +} + +func TestChannelToResponse_FullChannel(t *testing.T) { + now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 42, + Name: "test-channel", + Description: "desc", + Status: "active", + BillingModelSource: "upstream", + RestrictModels: true, + CreatedAt: now, + UpdatedAt: now.Add(time.Hour), + GroupIDs: []int64{1, 2, 3}, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 10, + Platform: "openai", + Models: []string{"gpt-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + PerRequestPrice: float64Ptr(0.5), + }, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-3-haiku": "claude-haiku-3"}, + }, + } + + resp := channelToResponse(ch) + require.NotNil(t, resp) + require.Equal(t, int64(42), resp.ID) + require.Equal(t, "test-channel", resp.Name) + require.Equal(t, "desc", resp.Description) + require.Equal(t, "active", resp.Status) + require.Equal(t, "upstream", resp.BillingModelSource) + require.True(t, resp.RestrictModels) + require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs) + require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt) + require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt) + + // model mapping + require.Len(t, resp.ModelMapping, 1) + require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"]) + + // pricing + require.Len(t, resp.ModelPricing, 1) + p := resp.ModelPricing[0] + require.Equal(t, int64(10), p.ID) + require.Equal(t, "openai", p.Platform) + require.Equal(t, []string{"gpt-4"}, p.Models) + require.Equal(t, "token", p.BillingMode) + require.Equal(t, float64Ptr(0.01), p.InputPrice) + require.Equal(t, float64Ptr(0.03), p.OutputPrice) + require.Equal(t, float64Ptr(0.005), p.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), p.CacheReadPrice) + require.Equal(t, float64Ptr(0.5), p.PerRequestPrice) + require.Empty(t, p.Intervals) +} + +func TestChannelToResponse_EmptyDefaults(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 1, + Name: "ch", + BillingModelSource: "", + CreatedAt: now, + UpdatedAt: now, + GroupIDs: nil, + ModelMapping: nil, + ModelPricing: []service.ChannelModelPricing{ + { + Platform: "", + BillingMode: "", + Models: []string{"m1"}, + }, + }, + } + + resp := channelToResponse(ch) + require.Equal(t, "channel_mapped", resp.BillingModelSource) + require.NotNil(t, resp.GroupIDs) + require.Empty(t, resp.GroupIDs) + require.NotNil(t, resp.ModelMapping) + require.Empty(t, resp.ModelMapping) + + require.Len(t, resp.ModelPricing, 1) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) +} + +func TestChannelToResponse_NilModels(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: nil, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + require.NotNil(t, resp.ModelPricing[0].Models) + require.Empty(t, resp.ModelPricing[0].Models) +} + +func TestChannelToResponse_WithIntervals(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: []string{"m1"}, + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + { + ID: 100, + MinTokens: 0, + MaxTokens: intPtr(1000), + TierLabel: "1K", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + SortOrder: 1, + }, + { + ID: 101, + MinTokens: 1000, + MaxTokens: nil, + TierLabel: "unlimited", + SortOrder: 2, + }, + }, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + intervals := resp.ModelPricing[0].Intervals + require.Len(t, intervals, 2) + + iv0 := intervals[0] + require.Equal(t, int64(100), iv0.ID) + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(1000), iv0.MaxTokens) + require.Equal(t, "1K", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) + require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := intervals[1] + require.Equal(t, int64(101), iv1.ID) + require.Equal(t, 1000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "unlimited", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestChannelToResponse_MultipleEntries(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "multi", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 1, + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.003), + OutputPrice: float64Ptr(0.015), + }, + { + ID: 2, + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(1.0), + }, + { + ID: 3, + Platform: "gemini", + Models: []string{"gemini-2.5-pro"}, + BillingMode: service.BillingModeImage, + ImageOutputPrice: float64Ptr(0.05), + PerRequestPrice: float64Ptr(0.2), + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 3) + + require.Equal(t, int64(1), resp.ModelPricing[0].ID) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) + + require.Equal(t, int64(2), resp.ModelPricing[1].ID) + require.Equal(t, "openai", resp.ModelPricing[1].Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models) + require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode) + + require.Equal(t, int64(3), resp.ModelPricing[2].ID) + require.Equal(t, "gemini", resp.ModelPricing[2].Platform) + require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models) + require.Equal(t, "image", resp.ModelPricing[2].BillingMode) + require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice) +} + +// --------------------------------------------------------------------------- +// 2. pricingRequestToService +// --------------------------------------------------------------------------- + +func TestPricingRequestToService_Defaults(t *testing.T) { + tests := []struct { + name string + req channelModelPricingRequest + wantField string // which default field to check + wantValue string + }{ + { + name: "empty billing mode defaults to token", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + BillingMode: "", + }, + wantField: "BillingMode", + wantValue: string(service.BillingModeToken), + }, + { + name: "empty platform defaults to anthropic", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + Platform: "", + }, + wantField: "Platform", + wantValue: "anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{tt.req}) + require.Len(t, result, 1) + switch tt.wantField { + case "BillingMode": + require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode) + case "Platform": + require.Equal(t, tt.wantValue, result[0].Platform) + } + }) + } +} + +func TestPricingRequestToService_WithAllFields(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: "per_request", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + ImageOutputPrice: float64Ptr(0.04), + PerRequestPrice: float64Ptr(0.5), + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Equal(t, "openai", r.Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models) + require.Equal(t, service.BillingModePerRequest, r.BillingMode) + require.Equal(t, float64Ptr(0.01), r.InputPrice) + require.Equal(t, float64Ptr(0.03), r.OutputPrice) + require.Equal(t, float64Ptr(0.005), r.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), r.CacheReadPrice) + require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice) + require.Equal(t, float64Ptr(0.5), r.PerRequestPrice) +} + +func TestPricingRequestToService_WithIntervals(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "per_request", + Intervals: []pricingIntervalRequest{ + { + MinTokens: 0, + MaxTokens: intPtr(2000), + TierLabel: "small", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + SortOrder: 1, + }, + { + MinTokens: 2000, + MaxTokens: nil, + TierLabel: "large", + SortOrder: 2, + }, + }, + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + require.Len(t, result[0].Intervals, 2) + + iv0 := result[0].Intervals[0] + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(2000), iv0.MaxTokens) + require.Equal(t, "small", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) + require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := result[0].Intervals[1] + require.Equal(t, 2000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "large", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestPricingRequestToService_EmptySlice(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{}) + require.NotNil(t, result) + require.Empty(t, result) +} + +func TestPricingRequestToService_NilPriceFields(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "token", + // all price fields are nil by default + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Nil(t, r.InputPrice) + require.Nil(t, r.OutputPrice) + require.Nil(t, r.CacheWritePrice) + require.Nil(t, r.CacheReadPrice) + require.Nil(t, r.ImageOutputPrice) + require.Nil(t, r.PerRequestPrice) +} + +// --------------------------------------------------------------------------- +// 3. validatePricingBillingMode +// --------------------------------------------------------------------------- + +func TestValidatePricingBillingMode(t *testing.T) { + tests := []struct { + name string + pricing []service.ChannelModelPricing + wantErr bool + }{ + { + name: "token mode - valid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeToken}, + }, + wantErr: false, + }, + { + name: "per_request with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + }, + wantErr: false, + }, + { + name: "per_request with intervals - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + {MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)}, + }, + }, + }, + wantErr: false, + }, + { + name: "per_request no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModePerRequest}, + }, + wantErr: true, + }, + { + name: "image with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeImage, + PerRequestPrice: float64Ptr(0.2), + }, + }, + wantErr: false, + }, + { + name: "image no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeImage}, + }, + wantErr: true, + }, + { + name: "empty list - valid", + pricing: []service.ChannelModelPricing{}, + wantErr: false, + }, + { + name: "mixed modes with invalid image - invalid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + }, + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + { + BillingMode: service.BillingModeImage, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePricingBillingMode(tt.pricing) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "per-request price or intervals required") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 2a214471..460f6357 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { dim.Endpoint = c.Query("endpoint") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") + // Additional filter conditions + if v := c.Query("user_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.UserID = id + } + } + if v := c.Query("api_key_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.APIKeyID = id + } + } + if v := c.Query("account_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.AccountID = id + } + } + if v := c.Query("request_type"); v != "" { + if rt, err := strconv.ParseInt(v, 10, 16); err == nil { + rtVal := int16(rt) + dim.RequestType = &rtVal + } + } + if v := c.Query("stream"); v != "" { + if s, err := strconv.ParseBool(v); err == nil { + dim.Stream = &s + } + } + if v := c.Query("billing_type"); v != "" { + if bt, err := strconv.ParseInt(v, 10, 8); err == nil { + btVal := int8(bt) + dim.BillingType = &btVal + } + } + limit := 50 if v := c.Query("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 459fd949..458ed35d 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -95,10 +95,6 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -108,10 +104,10 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` - // Sora 存储配额 - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -121,7 +117,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -133,10 +129,6 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -146,10 +138,10 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` - // Sora 存储配额 - SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` @@ -254,10 +246,6 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -265,8 +253,9 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -307,10 +296,6 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -318,8 +303,9 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, + RequireOAuthOnly: req.RequireOAuthOnly, + RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index 4e6179db..cc0c9337 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct { } func oauthPlatformFromPath(c *gin.Context) string { - if strings.Contains(c.FullPath(), "/admin/sora/") { - return service.PlatformSora - } return service.PlatformOpenAI } @@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct { // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token -// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// ExchangeSoraSessionToken exchanges Sora session token to access token -// POST /api/v1/admin/sora/st2at -func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { - var req struct { - SessionToken string `json:"session_token"` - ST string `json:"st"` - ProxyID *int64 `json:"proxy_id"` - } - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - sessionToken := strings.TrimSpace(req.SessionToken) - if sessionToken == "" { - sessionToken = strings.TrimSpace(req.ST) - } - if sessionToken == "" { - response.BadRequest(c, "session_token is required") - return - } - - tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, tokenInfo) -} - -// RefreshAccountToken refreshes token for a specific OpenAI/Sora account +// RefreshAccountToken refreshes token for a specific OpenAI account // POST /api/v1/admin/openai/accounts/:id/refresh -// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth -// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` @@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { name = tokenInfo.Email } if name == "" { - if platform == service.PlatformSora { - name = "Sora OAuth Account" - } else { - name = "OpenAI OAuth Account" - } + name = "OpenAI OAuth Account" } // Create account diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 13ea88d9..c494e5fb 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -35,9 +35,9 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service. type GenerateRedeemCodesRequest struct { Count int `json:"count" binding:"required,min=1,max=100"` Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` - Value float64 `json:"value" binding:"min=0"` - GroupID *int64 `json:"group_id"` // 订阅类型必填 - ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 + Value float64 `json:"value"` + GroupID *int64 `json:"group_id"` // 订阅类型必填 + ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减 } // CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. @@ -45,10 +45,10 @@ type GenerateRedeemCodesRequest struct { type CreateAndRedeemCodeRequest struct { Code string `json:"code" binding:"required,min=3,max=128"` Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容) - Value float64 `json:"value" binding:"required,gt=0"` + Value float64 `json:"value" binding:"required"` UserID int64 `json:"user_id" binding:"required,gt=0"` - GroupID *int64 `json:"group_id"` // subscription 类型必填 - ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0 + GroupID *int64 `json:"group_id"` // subscription 类型必填 + ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减 Notes string `json:"notes"` } @@ -150,8 +150,8 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { response.BadRequest(c, "group_id is required for subscription type") return } - if req.ValidityDays <= 0 { - response.BadRequest(c, "validity_days must be greater than 0 for subscription type") + if req.ValidityDays == 0 { + response.BadRequest(c, "validity_days must not be zero for subscription type") return } } diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go index 0d42f64f..f1f7778f 100644 --- a/backend/internal/handler/admin/redeem_handler_test.go +++ b/backend/internal/handler/admin/redeem_handler_test.go @@ -76,32 +76,38 @@ func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) { assert.Equal(t, http.StatusBadRequest, code) } -func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) { +func TestCreateAndRedeem_SubscriptionRequiresNonZeroValidityDays(t *testing.T) { groupID := int64(5) h := newCreateAndRedeemHandler() - cases := []struct { - name string - validityDays int - }{ - {"zero", 0}, - {"negative", -1}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - code := postCreateAndRedeemValidation(t, h, map[string]any{ - "code": "test-sub-bad-days-" + tc.name, - "type": "subscription", - "value": 29.9, - "user_id": 1, - "group_id": groupID, - "validity_days": tc.validityDays, - }) - - assert.Equal(t, http.StatusBadRequest, code) + // zero should be rejected + t.Run("zero", func(t *testing.T) { + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-bad-days-zero", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": 0, }) - } + + assert.Equal(t, http.StatusBadRequest, code) + }) + + // negative should pass validation (used for refund/reduction) + t.Run("negative_passes_validation", func(t *testing.T) { + code := postCreateAndRedeemValidation(t, h, map[string]any{ + "code": "test-sub-negative-days", + "type": "subscription", + "value": 29.9, + "user_id": 1, + "group_id": groupID, + "validity_days": -7, + }) + + assert.NotEqual(t, http.StatusBadRequest, code, + "negative validity_days should pass validation for refund") + }) } func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 397526a7..06916917 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -41,17 +41,15 @@ type SettingHandler struct { emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService - soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, - soraS3Storage: soraS3Storage, } } @@ -108,7 +106,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, @@ -177,7 +174,6 @@ type UpdateSettingsRequest struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` @@ -566,7 +562,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, - SoraClientEnabled: req.SoraClientEnabled, CustomMenuItems: customMenuJSON, CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, @@ -676,7 +671,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, - SoraClientEnabled: updatedSettings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, @@ -1207,384 +1201,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } -func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { - if settings == nil { - return dto.SoraS3Settings{} - } - return dto.SoraS3Settings{ - Enabled: settings.Enabled, - Endpoint: settings.Endpoint, - Region: settings.Region, - Bucket: settings.Bucket, - AccessKeyID: settings.AccessKeyID, - SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, - Prefix: settings.Prefix, - ForcePathStyle: settings.ForcePathStyle, - CDNURL: settings.CDNURL, - DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, - } -} - -func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { - return dto.SoraS3Profile{ - ProfileID: profile.ProfileID, - Name: profile.Name, - IsActive: profile.IsActive, - Enabled: profile.Enabled, - Endpoint: profile.Endpoint, - Region: profile.Region, - Bucket: profile.Bucket, - AccessKeyID: profile.AccessKeyID, - SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, - Prefix: profile.Prefix, - ForcePathStyle: profile.ForcePathStyle, - CDNURL: profile.CDNURL, - DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, - UpdatedAt: profile.UpdatedAt, - } -} - -func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { - if !enabled { - return nil - } - if strings.TrimSpace(endpoint) == "" { - return fmt.Errorf("S3 Endpoint is required when enabled") - } - if strings.TrimSpace(bucket) == "" { - return fmt.Errorf("S3 Bucket is required when enabled") - } - if strings.TrimSpace(accessKeyID) == "" { - return fmt.Errorf("S3 Access Key ID is required when enabled") - } - if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { - return nil - } - return fmt.Errorf("S3 Secret Access Key is required when enabled") -} - -func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == profileID { - return &items[idx] - } - } - return nil -} - -// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) -// GET /api/v1/admin/settings/sora-s3 -func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { - settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3SettingsDTO(settings)) -} - -// ListSoraS3Profiles 获取 Sora S3 多配置 -// GET /api/v1/admin/settings/sora-s3/profiles -func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) { - result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - items := make([]dto.SoraS3Profile, 0, len(result.Items)) - for idx := range result.Items { - items = append(items, toSoraS3ProfileDTO(result.Items[idx])) - } - response.Success(c, dto.ListSoraS3ProfilesResponse{ - ActiveProfileID: result.ActiveProfileID, - Items: items, - }) -} - -// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) -type UpdateSoraS3SettingsRequest struct { - ProfileID string `json:"profile_id"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -type CreateSoraS3ProfileRequest struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - SetActive bool `json:"set_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -type UpdateSoraS3ProfileRequest struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// CreateSoraS3Profile 创建 Sora S3 配置 -// POST /api/v1/admin/settings/sora-s3/profiles -func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { - var req CreateSoraS3ProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if strings.TrimSpace(req.Name) == "" { - response.BadRequest(c, "Name is required") - return - } - if strings.TrimSpace(req.ProfileID) == "" { - response.BadRequest(c, "Profile ID is required") - return - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { - response.BadRequest(c, err.Error()) - return - } - - created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ - ProfileID: req.ProfileID, - Name: req.Name, - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - }, req.SetActive) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, toSoraS3ProfileDTO(*created)) -} - -// UpdateSoraS3Profile 更新 Sora S3 配置 -// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id -func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - - var req UpdateSoraS3ProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if strings.TrimSpace(req.Name) == "" { - response.BadRequest(c, "Name is required") - return - } - - existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - existing := findSoraS3ProfileByID(existingList.Items, profileID) - if existing == nil { - response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound) - return - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { - response.BadRequest(c, err.Error()) - return - } - - updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ - Name: req.Name, - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - }) - if updateErr != nil { - response.ErrorFrom(c, updateErr) - return - } - - response.Success(c, toSoraS3ProfileDTO(*updated)) -} - -// DeleteSoraS3Profile 删除 Sora S3 配置 -// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id -func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"deleted": true}) -} - -// SetActiveSoraS3Profile 切换激活 Sora S3 配置 -// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate -func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3ProfileDTO(*active)) -} - -// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) -// PUT /api/v1/admin/settings/sora-s3 -func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) { - var req UpdateSoraS3SettingsRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { - response.BadRequest(c, err.Error()) - return - } - - settings := &service.SoraS3Settings{ - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - } - if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { - response.ErrorFrom(c, err) - return - } - - updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3SettingsDTO(updatedSettings)) -} - -// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) -// POST /api/v1/admin/settings/sora-s3/test -func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { - if h.soraS3Storage == nil { - response.Error(c, 500, "S3 存储服务未初始化") - return - } - - var req UpdateSoraS3SettingsRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - if !req.Enabled { - response.BadRequest(c, "S3 未启用,无法测试连接") - return - } - - if req.SecretAccessKey == "" { - if req.ProfileID != "" { - profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err == nil { - profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID) - if profile != nil { - req.SecretAccessKey = profile.SecretAccessKey - } - } - } - if req.SecretAccessKey == "" { - existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err == nil { - req.SecretAccessKey = existing.SecretAccessKey - } - } - } - - testCfg := &service.SoraS3Settings{ - Enabled: true, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - } - if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { - response.Error(c, 400, "S3 连接测试失败: "+err.Error()) - return - } - response.Success(c, gin.H{"message": "S3 连接成功"}) -} - // GetRectifierSettings 获取请求整流器配置 // GET /api/v1/admin/settings/rectifier func (h *SettingHandler) GetRectifierSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 7a3135b8..2967b384 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) { } model := c.Query("model") + billingMode := strings.TrimSpace(c.Query("billing_mode")) var requestType *int16 var stream *bool @@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) { RequestType: requestType, Stream: stream, BillingType: billingType, + BillingMode: billingMode, StartTime: startTime, EndTime: endTime, ExactTotal: exactTotal, @@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { } model := c.Query("model") + billingMode := strings.TrimSpace(c.Query("billing_mode")) var requestType *int16 var stream *bool @@ -312,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { RequestType: requestType, Stream: stream, BillingType: billingType, + BillingMode: billingMode, StartTime: &startTime, EndTime: &endTime, } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 998308dd..a357657e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -34,14 +34,13 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups []int64 `json:"allowed_groups"` - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 `json:"allowed_groups"` } // UpdateUserRequest represents admin update user request @@ -57,8 +56,7 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` - SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` + GroupRates map[int64]*float64 `json:"group_rates"` } // UpdateBalanceRequest represents balance update request @@ -182,14 +180,13 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, }) if err != nil { response.ErrorFrom(c, err) @@ -216,16 +213,15 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 0b5448af..2eab670e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, - SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, - SoraStorageUsedBytes: u.SoraStorageUsedBytes, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, } } @@ -172,15 +170,12 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, - SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOAuthOnly, + RequirePrivacySet: g.RequirePrivacySet, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -575,6 +570,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { MediaType: l.MediaType, UserAgent: l.UserAgent, CacheTTLOverridden: l.CacheTTLOverridden, + BillingMode: l.BillingMode, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), @@ -602,6 +598,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), UpstreamModel: l.UpstreamModel, + ChannelID: l.ChannelID, + ModelMappingChain: l.ModelMappingChain, + BillingTier: l.BillingTier, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 47bab091..aecbf0c8 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -61,7 +61,6 @@ type SystemSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` @@ -128,49 +127,10 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version"` } -// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) -type SoraS3Settings struct { - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) -type SoraS3Profile struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// ListSoraS3ProfilesResponse Sora S3 配置列表响应 -type ListSoraS3ProfilesResponse struct { - ActiveProfileID string `json:"active_profile_id"` - Items []SoraS3Profile `json:"items"` -} - // OverloadCooldownSettings 529过载冷却配置 DTO type OverloadCooldownSettings struct { Enabled bool `json:"enabled"` @@ -197,10 +157,13 @@ type RectifierSettings struct { // BetaPolicyRule Beta 策略规则 DTO type BetaPolicyRule struct { - BetaToken string `json:"beta_token"` - Action string `json:"action"` - Scope string `json:"scope"` - ErrorMessage string `json:"error_message,omitempty"` + BetaToken string `json:"beta_token"` + Action string `json:"action"` + Scope string `json:"scope"` + ErrorMessage string `json:"error_message,omitempty"` + ModelWhitelist []string `json:"model_whitelist,omitempty"` + FallbackAction string `json:"fallback_action,omitempty"` + FallbackErrorMessage string `json:"fallback_error_message,omitempty"` } // BetaPolicySettings Beta 策略配置 DTO diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8af6990e..82065deb 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,9 +26,7 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates map[int64]float64 `json:"group_rates,omitempty"` - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` - SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` } type APIKey struct { @@ -84,24 +82,19 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - // Sora 按次计费配置 - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` - // Sora 存储配额 - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` - // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + // 账号过滤控制(仅 OpenAI/Antigravity 平台有效) + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -386,6 +379,9 @@ type UsageLog struct { // Cache TTL Override 标记 CacheTTLOverridden bool `json:"cache_ttl_overridden"` + // BillingMode 计费模式:token/image + BillingMode *string `json:"billing_mode,omitempty"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` @@ -402,6 +398,13 @@ type AdminUsageLog struct { // Omitted when no mapping was applied (requested model was used as-is). UpstreamModel *string `json:"upstream_model,omitempty"` + // ChannelID 渠道 ID + ChannelID *int64 `json:"channel_id,omitempty"` + // ModelMappingChain 模型映射链,如 "a→b→c" + ModelMappingChain *string `json:"model_mapping_chain,omitempty"` + // BillingTier 计费层级标签(per_request/image 模式) + BillingTier *string `json:"billing_tier,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index b1200988..a897bc40 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -31,7 +31,7 @@ const ( // ────────────────────────────────────────────────────────── // NormalizeInboundEndpoint maps a raw request path (which may carry -// prefixes like /antigravity, /openai, /sora) to its canonical form. +// prefixes like /antigravity, /openai) to its canonical form. // // "/antigravity/v1/messages" → "/v1/messages" // "/v1/chat/completions" → "/v1/chat/completions" @@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string { // such as /v1/responses/compact preserved from the raw URL). // - Anthropic → /v1/messages // - Gemini → /v1beta/models -// - Sora → /v1/chat/completions +// - Antigravity → /v1/messages (Claude) or gemini (Gemini) // - Antigravity routes may target either Claude or Gemini, so the // inbound endpoint is used to distinguish. func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { @@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { case service.PlatformGemini: return EndpointGeminiModels - case service.PlatformSora: - return EndpointChatCompletions - case service.PlatformAntigravity: // Antigravity accounts serve both Claude and Gemini. if inbound == EndpointGeminiModels { diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index a3767ac4..1519bc9e 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) { {"/v1/responses", EndpointResponses}, {"/v1beta/models", EndpointGeminiModels}, - // Prefixed paths (antigravity, openai, sora). + // Prefixed paths (antigravity, openai). {"/antigravity/v1/messages", EndpointMessages}, {"/openai/v1/responses", EndpointResponses}, {"/openai/v1/responses/compact", EndpointResponses}, - {"/sora/v1/chat/completions", EndpointChatCompletions}, {"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels}, // Gin route patterns with wildcards. @@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { // Gemini. {"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels}, - // Sora. - {"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions}, - // OpenAI — always /v1/responses. {"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses}, {"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"}, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9..59619d50 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { @@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { parsedReq.OnUpstreamAccepted = queueRelease // ===== 用户消息串行队列 END ===== + // 应用渠道模型映射到请求 + if channelMapping.Mapped { + parsedReq.Model = channelMapping.MappedModel + parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel) + body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -847,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) { platform = forcedPlatform } - if platform == service.PlatformSora { - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": service.DefaultSoraModels(h.cfg), - }) - return - } - // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index da376036..be267332 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // Claude Code only restriction if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", @@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { fs := NewFailoverState(h.maxAccountSwitches, false) for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) @@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 5. Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.cc.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index d146d724..e908eb9e 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // Claude Code only restriction: // /v1/responses is never a Claude Code endpoint. // When claude_code_only is enabled, this endpoint is rejected. @@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { fs := NewFailoverState(h.maxAccountSwitches, false) for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) @@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 5. Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.responses.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 69c8d1d5..4caef955 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService + nil, // resolver ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 5dc03b6d..d200c17c 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) + reqModel := modelName // 保存映射前的原始模型名 + if channelMapping.Mapped { + modelName = channelMapping.MappedModel + } + // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) @@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gemini_v1beta.models"), @@ -674,6 +682,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go index 82b30ee4..29d7cc41 100644 --- a/backend/internal/handler/gemini_v1beta_handler_test.go +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -3,6 +3,7 @@ package handler import ( + "net/http" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { }) } } + +func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res)) +} + +func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.False(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} + +func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{ + StatusCode: http.StatusForbidden, + Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}}, + Body: []byte("insufficient authentication scopes"), + } + require.True(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2467eac..d4c349fb 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -30,6 +30,7 @@ type AdminHandlers struct { TLSFingerprintProfile *admin.TLSFingerprintProfileHandler APIKey *admin.AdminAPIKeyHandler ScheduledTest *admin.ScheduledTestHandler + Channel *admin.ChannelHandler } // Handlers contains all HTTP handlers @@ -44,8 +45,6 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler - SoraGateway *SoraGatewayHandler - SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 0c94aa21..991cbb91 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) } @@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { forwardStart := time.Now() defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.chat_completions"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index ae70cee4..4747ccfe 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { return @@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Forward request service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() @@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) - result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMappingMsg.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel) + } + result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + // 解析渠道级模型映射 + channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) + var currentUserRelease func() var currentAccountRelease func() releaseTurnSlots := func() { @@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), @@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { }, } - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + // 应用渠道模型映射到 WebSocket 首条消息 + wsFirstMessage := firstMessage + if channelMappingWS.Mapped { + wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) closeStatus, closeReason := summarizeWSCloseErrorForLog(err) reqLog.Warn("openai.websocket_proxy_failed", diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2c999cf1..977c2301 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: h.version, }) diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go deleted file mode 100644 index 80acc833..00000000 --- a/backend/internal/handler/sora_client_handler.go +++ /dev/null @@ -1,979 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" -) - -const ( - // 上游模型缓存 TTL - modelCacheTTL = 1 * time.Hour // 上游获取成功 - modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) -) - -// SoraClientHandler 处理 Sora 客户端 API 请求。 -type SoraClientHandler struct { - genService *service.SoraGenerationService - quotaService *service.SoraQuotaService - s3Storage *service.SoraS3Storage - soraGatewayService *service.SoraGatewayService - gatewayService *service.GatewayService - mediaStorage *service.SoraMediaStorage - apiKeyService *service.APIKeyService - - // 上游模型缓存 - modelCacheMu sync.RWMutex - cachedFamilies []service.SoraModelFamily - modelCacheTime time.Time - modelCacheUpstream bool // 是否来自上游(决定 TTL) -} - -// NewSoraClientHandler 创建 Sora 客户端 Handler。 -func NewSoraClientHandler( - genService *service.SoraGenerationService, - quotaService *service.SoraQuotaService, - s3Storage *service.SoraS3Storage, - soraGatewayService *service.SoraGatewayService, - gatewayService *service.GatewayService, - mediaStorage *service.SoraMediaStorage, - apiKeyService *service.APIKeyService, -) *SoraClientHandler { - return &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - s3Storage: s3Storage, - soraGatewayService: soraGatewayService, - gatewayService: gatewayService, - mediaStorage: mediaStorage, - apiKeyService: apiKeyService, - } -} - -// GenerateRequest 生成请求。 -type GenerateRequest struct { - Model string `json:"model" binding:"required"` - Prompt string `json:"prompt" binding:"required"` - MediaType string `json:"media_type"` // video / image,默认 video - VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) - ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) - APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID -} - -// Generate 异步生成 — 创建 pending 记录后立即返回。 -// POST /api/v1/sora/generate -func (h *SoraClientHandler) Generate(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - var req GenerateRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) - return - } - - if req.MediaType == "" { - req.MediaType = "video" - } - req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) - - // 并发数检查(最多 3 个) - activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - if activeCount >= 3 { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - - // 配额检查(粗略检查,实际文件大小在上传后才知道) - if h.quotaService != nil { - if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusForbidden, err.Error()) - return - } - } - - // 获取 API Key ID 和 Group ID - var apiKeyID *int64 - var groupID *int64 - - if req.APIKeyID != nil && h.apiKeyService != nil { - // 前端传递了 api_key_id,需要校验 - apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) - if err != nil { - response.Error(c, http.StatusBadRequest, "API Key 不存在") - return - } - if apiKey.UserID != userID { - response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") - return - } - if apiKey.Status != service.StatusAPIKeyActive { - response.Error(c, http.StatusForbidden, "API Key 不可用") - return - } - apiKeyID = &apiKey.ID - groupID = apiKey.GroupID - } else if id, ok := c.Get("api_key_id"); ok { - // 兼容 API Key 认证路径(/sora/v1/ 网关路由) - if v, ok := id.(int64); ok { - apiKeyID = &v - } - } - - gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) - if err != nil { - if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - response.ErrorFrom(c, err) - return - } - - // 启动后台异步生成 goroutine - go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) - - response.Success(c, gin.H{ - "generation_id": gen.ID, - "status": gen.Status, - }) -} - -// processGeneration 后台异步执行 Sora 生成任务。 -// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 -func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - // 标记为生成中 - if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", - genID, - userID, - groupIDForLog(groupID), - model, - mediaType, - videoCount, - strings.TrimSpace(imageInput) != "", - len(strings.TrimSpace(prompt)), - ) - - // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 - if groupID == nil { - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - } - - if h.gatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") - return - } - - // 选择 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", - genID, - userID, - groupIDForLog(groupID), - model, - err, - ) - _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) - return - } - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", - genID, - userID, - groupIDForLog(groupID), - model, - account.ID, - account.Name, - account.Platform, - account.Type, - ) - - // 构建 chat completions 请求体(非流式) - body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) - - if h.soraGatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") - return - } - - // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) - recorder := httptest.NewRecorder() - mockGinCtx, _ := gin.CreateTestContext(recorder) - mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) - - // 调用 Forward(非流式) - result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - err, - ) - // 检查是否已取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - return - } - _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) - return - } - - // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) - mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) - if mediaURL == "" { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - ) - _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") - return - } - - // 检查任务是否已被取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) - return - } - - // 三层降级存储:S3 → 本地 → 上游临时 URL - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) - - usageAdded := false - if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") - return - } - _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 - gen, _ = h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - - // 标记完成 - if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) -} - -// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 -func (h *SoraClientHandler) storeMediaWithDegradation( - ctx context.Context, userID int64, mediaType string, - mediaURL string, mediaURLs []string, -) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { - urls := mediaURLs - if len(urls) == 0 { - urls = []string{mediaURL} - } - - // 第一层:尝试 S3 - if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { - keys := make([]string, 0, len(urls)) - var totalSize int64 - allOK := true - for _, u := range urls { - key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) - allOK = false - // 清理已上传的文件 - if len(keys) > 0 { - _ = h.s3Storage.DeleteObjects(ctx, keys) - } - break - } - keys = append(keys, key) - totalSize += size - } - if allOK && len(keys) > 0 { - accessURLs := make([]string, 0, len(keys)) - for _, key := range keys { - accessURL, err := h.s3Storage.GetAccessURL(ctx, key) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) - _ = h.s3Storage.DeleteObjects(ctx, keys) - allOK = false - break - } - accessURLs = append(accessURLs, accessURL) - } - if allOK && len(accessURLs) > 0 { - return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize - } - } - } - - // 第二层:尝试本地存储 - if h.mediaStorage != nil && h.mediaStorage.Enabled() { - storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) - if err == nil && len(storedPaths) > 0 { - firstPath := storedPaths[0] - totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) - if sizeErr != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) - } - return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) - } - - // 第三层:保留上游临时 URL - return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 -} - -// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 -func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { - body := map[string]any{ - "model": model, - "messages": []map[string]string{ - {"role": "user", "content": prompt}, - }, - "stream": false, - } - if imageInput != "" { - body["image_input"] = imageInput - } - if videoCount > 1 { - body["video_count"] = videoCount - } - b, _ := json.Marshal(body) - return b -} - -func normalizeVideoCount(mediaType string, videoCount int) int { - if mediaType != "video" { - return 1 - } - if videoCount <= 0 { - return 1 - } - if videoCount > 3 { - return 3 - } - return videoCount -} - -// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 -// OAuth 路径:ForwardResult.MediaURL 已填充。 -// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 -func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { - // 优先从 ForwardResult 获取(OAuth 路径) - if result != nil && result.MediaURL != "" { - // 尝试从响应体获取完整 URL 列表 - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - return result.MediaURL, []string{result.MediaURL} - } - - // 从响应体解析(APIKey 路径) - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - - return "", nil -} - -// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 -func parseMediaURLsFromBody(body []byte) []string { - if len(body) == 0 { - return nil - } - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return nil - } - - // 优先 media_urls(多图数组) - if rawURLs, ok := resp["media_urls"]; ok { - if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { - urls := make([]string, 0, len(arr)) - for _, item := range arr { - if s, ok := item.(string); ok && s != "" { - urls = append(urls, s) - } - } - if len(urls) > 0 { - return urls - } - } - } - - // 回退到 media_url(单个 URL) - if url, ok := resp["media_url"].(string); ok && url != "" { - return []string{url} - } - - return nil -} - -// ListGenerations 查询生成记录列表。 -// GET /api/v1/sora/generations -func (h *SoraClientHandler) ListGenerations(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - - params := service.SoraGenerationListParams{ - UserID: userID, - Status: c.Query("status"), - StorageType: c.Query("storage_type"), - MediaType: c.Query("media_type"), - Page: page, - PageSize: pageSize, - } - - gens, total, err := h.genService.List(c.Request.Context(), params) - if err != nil { - response.ErrorFrom(c, err) - return - } - - // 为 S3 记录动态生成预签名 URL - for _, gen := range gens { - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - } - - response.Success(c, gin.H{ - "data": gens, - "total": total, - "page": page, - }) -} - -// GetGeneration 查询生成记录详情。 -// GET /api/v1/sora/generations/:id -func (h *SoraClientHandler) GetGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - response.Success(c, gen) -} - -// DeleteGeneration 删除生成记录。 -// DELETE /api/v1/sora/generations/:id -func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 - if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { - paths := gen.MediaURLs - if len(paths) == 0 && gen.MediaURL != "" { - paths = []string{gen.MediaURL} - } - if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) - } - } - - if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已删除"}) -} - -// GetQuota 查询用户存储配额。 -// GET /api/v1/sora/quota -func (h *SoraClientHandler) GetQuota(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - if h.quotaService == nil { - response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) - return - } - - quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, quota) -} - -// CancelGeneration 取消生成任务。 -// POST /api/v1/sora/generations/:id/cancel -func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - // 权限校验 - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - _ = gen - - if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { - if errors.Is(err, service.ErrSoraGenerationNotActive) { - response.Error(c, http.StatusConflict, "任务已结束,无法取消") - return - } - response.Error(c, http.StatusBadRequest, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已取消"}) -} - -// SaveToStorage 手动保存 upstream 记录到 S3。 -// POST /api/v1/sora/generations/:id/save -func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - if gen.StorageType != service.SoraStorageTypeUpstream { - response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") - return - } - if gen.MediaURL == "" { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { - response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") - return - } - - sourceURLs := gen.MediaURLs - if len(sourceURLs) == 0 && gen.MediaURL != "" { - sourceURLs = []string{gen.MediaURL} - } - if len(sourceURLs) == 0 { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - uploadedKeys := make([]string, 0, len(sourceURLs)) - accessURLs := make([]string, 0, len(sourceURLs)) - var totalSize int64 - - for _, sourceURL := range sourceURLs { - objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) - if uploadErr != nil { - if len(uploadedKeys) > 0 { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - } - var upstreamErr *service.UpstreamDownloadError - if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { - response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") - return - } - response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) - return - } - accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) - if err != nil { - uploadedKeys = append(uploadedKeys, objectKey) - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) - return - } - uploadedKeys = append(uploadedKeys, objectKey) - accessURLs = append(accessURLs, accessURL) - totalSize += fileSize - } - - usageAdded := false - if totalSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - if err := h.genService.UpdateStorageForCompleted( - c.Request.Context(), - id, - accessURLs[0], - accessURLs, - service.SoraStorageTypeS3, - uploadedKeys, - totalSize, - ); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) - } - response.ErrorFrom(c, err) - return - } - - response.Success(c, gin.H{ - "message": "已保存到 S3", - "object_key": uploadedKeys[0], - "object_keys": uploadedKeys, - }) -} - -// GetStorageStatus 返回存储状态。 -// GET /api/v1/sora/storage-status -func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { - s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) - s3Healthy := false - if s3Enabled { - s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) - } - localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() - response.Success(c, gin.H{ - "s3_enabled": s3Enabled, - "s3_healthy": s3Healthy, - "local_enabled": localEnabled, - }) -} - -func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { - switch storageType { - case service.SoraStorageTypeS3: - if h.s3Storage != nil && len(s3Keys) > 0 { - if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) - } - } - case service.SoraStorageTypeLocal: - if h.mediaStorage != nil && len(localPaths) > 0 { - if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) - } - } - } -} - -// getUserIDFromContext 从 gin 上下文中提取用户 ID。 -func getUserIDFromContext(c *gin.Context) int64 { - if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { - return subject.UserID - } - - if id, ok := c.Get("user_id"); ok { - switch v := id.(type) { - case int64: - return v - case float64: - return int64(v) - case string: - n, _ := strconv.ParseInt(v, 10, 64) - return n - } - } - // 尝试从 JWT claims 获取 - if id, ok := c.Get("userID"); ok { - if v, ok := id.(int64); ok { - return v - } - } - return 0 -} - -func groupIDForLog(groupID *int64) int64 { - if groupID == nil { - return 0 - } - return *groupID -} - -func trimForLog(raw string, maxLen int) string { - trimmed := strings.TrimSpace(raw) - if maxLen <= 0 || len(trimmed) <= maxLen { - return trimmed - } - return trimmed[:maxLen] + "...(truncated)" -} - -// GetModels 获取可用 Sora 模型家族列表。 -// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 -// GET /api/v1/sora/models -func (h *SoraClientHandler) GetModels(c *gin.Context) { - families := h.getModelFamilies(c.Request.Context()) - response.Success(c, families) -} - -// getModelFamilies 获取模型家族列表(带缓存)。 -func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { - // 读锁检查缓存 - h.modelCacheMu.RLock() - ttl := modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - families := h.cachedFamilies - h.modelCacheMu.RUnlock() - return families - } - h.modelCacheMu.RUnlock() - - // 写锁更新缓存 - h.modelCacheMu.Lock() - defer h.modelCacheMu.Unlock() - - // double-check - ttl = modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - return h.cachedFamilies - } - - // 尝试从上游获取 - families, err := h.fetchUpstreamModels(ctx) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) - families = service.BuildSoraModelFamilies() - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = false - return families - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = true - return families -} - -// fetchUpstreamModels 从上游 Sora API 获取模型列表。 -func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { - if h.gatewayService == nil { - return nil, fmt.Errorf("gatewayService 未初始化") - } - - // 设置 ForcePlatform 用于 Sora 账号选择 - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - - // 选择一个 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") - if err != nil { - return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) - } - - // 仅支持 API Key 类型账号 - if account.Type != service.AccountTypeAPIKey { - return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) - } - - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return nil, fmt.Errorf("账号缺少 api_key") - } - - baseURL := account.GetBaseURL() - if baseURL == "" { - return nil, fmt.Errorf("账号缺少 base_url") - } - - // 构建上游模型列表请求 - modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" - - reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) - if err != nil { - return nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+apiKey) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("请求上游失败: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - // 解析 OpenAI 格式的模型列表 - var modelsResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.Unmarshal(body, &modelsResp); err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) - } - - if len(modelsResp.Data) == 0 { - return nil, fmt.Errorf("上游返回空模型列表") - } - - // 提取模型 ID - modelIDs := make([]string, 0, len(modelsResp.Data)) - for _, m := range modelsResp.Data { - modelIDs = append(modelIDs, m.ID) - } - - // 转换为模型家族 - families := service.BuildSoraModelFamiliesFromIDs(modelIDs) - if len(families) == 0 { - return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") - } - - return families, nil -} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go deleted file mode 100644 index fe035b6f..00000000 --- a/backend/internal/handler/sora_client_handler_test.go +++ /dev/null @@ -1,3178 +0,0 @@ -//go:build unit - -package handler - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -// ==================== Stub: SoraGenerationRepository ==================== - -var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) - -type stubSoraGenRepo struct { - gens map[int64]*service.SoraGeneration - nextID int64 - createErr error - getErr error - updateErr error - deleteErr error - listErr error - countErr error - countValue int64 - - // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 - updateCallCount *int32 - updateFailAfterN int32 - - // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus - getByIDCallCount int32 - getByIDOverrideAfterN int32 // 0 = 不覆盖 - getByIDOverrideStatus string -} - -func newStubSoraGenRepo() *stubSoraGenRepo { - return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} -} - -func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { - if r.createErr != nil { - return r.createErr - } - gen.ID = r.nextID - r.nextID++ - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { - if r.getErr != nil { - return nil, r.getErr - } - gen, ok := r.gens[id] - if !ok { - return nil, fmt.Errorf("not found") - } - // 条件性状态覆盖:模拟外部取消等场景 - if r.getByIDOverrideAfterN > 0 { - n := atomic.AddInt32(&r.getByIDCallCount, 1) - if n > r.getByIDOverrideAfterN { - cp := *gen - cp.Status = r.getByIDOverrideStatus - return &cp, nil - } - } - return gen, nil -} -func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { - // 条件性失败:前 N 次成功,之后失败 - if r.updateCallCount != nil { - n := atomic.AddInt32(r.updateCallCount, 1) - if n > r.updateFailAfterN { - return fmt.Errorf("conditional update error (call #%d)", n) - } - } - if r.updateErr != nil { - return r.updateErr - } - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { - if r.deleteErr != nil { - return r.deleteErr - } - delete(r.gens, id) - return nil -} -func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - if r.listErr != nil { - return nil, 0, r.listErr - } - var result []*service.SoraGeneration - for _, gen := range r.gens { - if gen.UserID != params.UserID { - continue - } - result = append(result, gen) - } - return result, int64(len(result)), nil -} -func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { - if r.countErr != nil { - return 0, r.countErr - } - return r.countValue, nil -} - -// ==================== 辅助函数 ==================== - -func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { - genService := service.NewSoraGenerationService(repo, nil, nil) - return &SoraClientHandler{genService: genService} -} - -func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - if body != "" { - c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) - c.Request.Header.Set("Content-Type", "application/json") - } else { - c.Request = httptest.NewRequest(method, path, nil) - } - if userID > 0 { - c.Set("user_id", userID) - } - return c, rec -} - -func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { - t.Helper() - var resp map[string]any - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - return resp -} - -// ==================== 纯函数测试: buildAsyncRequestBody ==================== - -func TestBuildAsyncRequestBody(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "sora2-landscape-10s", parsed["model"]) - require.Equal(t, false, parsed["stream"]) - - msgs := parsed["messages"].([]any) - require.Len(t, msgs, 1) - msg := msgs[0].(map[string]any) - require.Equal(t, "user", msg["role"]) - require.Equal(t, "一只猫在跳舞", msg["content"]) -} - -func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "gpt-image", parsed["model"]) - msgs := parsed["messages"].([]any) - msg := msgs[0].(map[string]any) - require.Equal(t, "", msg["content"]) -} - -func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) -} - -func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, float64(3), parsed["video_count"]) -} - -func TestNormalizeVideoCount(t *testing.T) { - require.Equal(t, 1, normalizeVideoCount("video", 0)) - require.Equal(t, 2, normalizeVideoCount("video", 2)) - require.Equal(t, 3, normalizeVideoCount("video", 5)) - require.Equal(t, 1, normalizeVideoCount("image", 3)) -} - -// ==================== 纯函数测试: parseMediaURLsFromBody ==================== - -func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) - require.Equal(t, []string{"https://a.com/video.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody(nil)) - require.Nil(t, parseMediaURLsFromBody([]byte{})) -} - -func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) -} - -func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) -} - -func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { - body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` - urls := parseMediaURLsFromBody([]byte(body)) - require.Len(t, urls, 2) - require.Equal(t, "https://multi.com/a.mp4", urls[0]) -} - -func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) -} - -func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { - // media_urls 不是 string 数组 - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) -} - -func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) -} - -// ==================== 纯函数测试: extractMediaURLsFromResult ==================== - -func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://oauth.com/video.mp4", url) - require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://body.com/1.mp4", url) - require.Len(t, urls, 2) -} - -func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Equal(t, "https://upstream.com/video.mp4", url) - require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { - result := &service.ForwardResult{MediaURL: ""} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -// ==================== getUserIDFromContext ==================== - -func TestGetUserIDFromContext_Int64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", int64(42)) - require.Equal(t, int64(42), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_AuthSubject(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) - require.Equal(t, int64(777), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_Float64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", float64(99)) - require.Equal(t, int64(99), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_String(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "123") - require.Equal(t, int64(123), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("userID", int64(55)) - require.Equal(t, int64(55), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_NoID(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_InvalidString(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "not-a-number") - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -// ==================== Handler: Generate ==================== - -func TestGenerate_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) - h.Generate(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGenerate_BadRequest_MissingModel(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_TooManyRequests(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countValue = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -func TestGenerate_CountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_Success(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - require.Equal(t, "pending", data["status"]) -} - -func TestGenerate_DefaultMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "video", repo.gens[1].MediaType) -} - -func TestGenerate_ImageMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "image", repo.gens[1].MediaType) -} - -func TestGenerate_CreatePendingError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.createErr = fmt.Errorf("create failed") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestGenerate_APIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(42)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_ConcurrencyBoundary(t *testing.T) { - // activeCount == 2 应该允许 - repo := newStubSoraGenRepo() - repo.countValue = 2 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Handler: ListGenerations ==================== - -func TestListGenerations_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) - h.ListGenerations(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestListGenerations_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} - repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} - repo.nextID = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - items := data["data"].([]any) - require.Len(t, items, 2) - require.Equal(t, float64(2), data["total"]) -} - -func TestListGenerations_ListError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.listErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestListGenerations_DefaultPagination(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - // 不传分页参数,应默认 page=1 page_size=20 - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["page"]) -} - -// ==================== Handler: GetGeneration ==================== - -func TestGetGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.GetGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGetGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["id"]) -} - -// ==================== Handler: DeleteGeneration ==================== - -func TestDeleteGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestDeleteGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestDeleteGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -// ==================== Handler: CancelGeneration ==================== - -func TestCancelGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestCancelGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestCancelGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_Pending(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Generating(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Completed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Failed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Cancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -// ==================== Handler: GetQuota ==================== - -func TestGetQuota_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) - h.GetQuota(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetQuota_NilQuotaService(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, "unlimited", data["source"]) -} - -// ==================== Handler: GetModels ==================== - -func TestGetModels(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) - h.GetModels(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].([]any) - require.Len(t, data, 4) - // 验证类型分布 - videoCount, imageCount := 0, 0 - for _, item := range data { - m := item.(map[string]any) - if m["type"] == "video" { - videoCount++ - } else if m["type"] == "image" { - imageCount++ - } - } - require.Equal(t, 3, videoCount) - require.Equal(t, 1, imageCount) -} - -// ==================== Handler: GetStorageStatus ==================== - -func TestGetStorageStatus_NilS3(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, false, data["local_enabled"]) -} - -func TestGetStorageStatus_LocalEnabled(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, true, data["local_enabled"]) -} - -// ==================== Handler: SaveToStorage ==================== - -func TestSaveToStorage_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestSaveToStorage_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestSaveToStorage_NotUpstream(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_EmptyMediaURL(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_S3Nil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusServiceUnavailable, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "云存储") -} - -func TestSaveToStorage_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== storeMediaWithDegradation — nil guard 路径 ==================== - -func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) - require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://a.com/1.mp4", url) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { - h := &SoraClientHandler{} - url, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) -} - -// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== - -var _ service.UserRepository = (*stubUserRepoForHandler)(nil) - -type stubUserRepoForHandler struct { - users map[int64]*service.User - updateErr error -} - -func newStubUserRepoForHandler() *stubUserRepoForHandler { - return &stubUserRepoForHandler{users: make(map[int64]*service.User)} -} - -func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { - if u, ok := r.users[id]; ok { - return u, nil - } - return nil, fmt.Errorf("user not found") -} -func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { - if r.updateErr != nil { - return r.updateErr - } - r.users[user.ID] = user - return nil -} -func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } -func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } -func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - return nil -} -func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { - return nil -} - -// ==================== NewSoraClientHandler ==================== - -func TestNewSoraClientHandler(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) -} - -func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) - require.Nil(t, h.apiKeyService) -} - -// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== - -var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) - -type stubAPIKeyRepoForHandler struct { - keys map[int64]*service.APIKey - getErr error -} - -func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { - return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} -} - -func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { - if r.getErr != nil { - return nil, r.getErr - } - if k, ok := r.keys[id]; ok { - return k, nil - } - return nil, fmt.Errorf("api key not found: %d", id) -} -func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { - return "", 0, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { - var updated int64 - for id, key := range r.keys { - if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID { - continue - } - clone := *key - gid := newGroupID - clone.GroupID = &gid - r.keys[id] = &clone - updated++ - } - return updated, nil -} -func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) { - return nil, nil -} - -// newTestAPIKeyService 创建测试用的 APIKeyService -func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { - return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) -} - -// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== - -func TestGenerate_WithAPIKeyID_Success(t *testing.T) { - // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(5) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - - // 验证 api_key_id 已关联到生成记录 - gen := repo.gens[1] - require.NotNil(t, gen.APIKeyID) - require.Equal(t, int64(42), *gen.APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { - // 前端传递不存在的 api_key_id → 400 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不存在") -} - -func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { - // 前端传递别人的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 999, // 属于 user 999 - Status: service.StatusAPIKeyActive, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不属于") -} - -func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { - // 前端传递已禁用的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyDisabled, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不可用") -} - -func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { - // 前端传递配额耗尽的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyQuotaExhausted, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { - // 前端传递已过期的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyExpired, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { - // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - h := &SoraClientHandler{genService: genService} // apiKeyService = nil - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { - // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: nil, // 无分组 - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { - // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { - // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(10) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { - // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(99)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 context 中的 api_key_id=99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(99), *repo.gens[1].APIKeyID) -} - -func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { - // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 - // api_key_id=0 不存在 → 400 - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== - -func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { - // groupID 不为 nil → 不设置 ForcePlatform - // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - gid := int64(5) - h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { - // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -// ==================== GenerateRequest JSON 解析 ==================== - -func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { - // 验证 api_key_id 在 JSON 中正确解析为 *int64 - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) - require.NoError(t, err) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(42), *req.APIKeyID) -} - -func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { - // 不传 api_key_id → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { - // api_key_id: null → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { - // 全字段解析 - var req GenerateRequest - err := json.Unmarshal([]byte(`{ - "model":"sora2-landscape-10s", - "prompt":"test prompt", - "media_type":"video", - "video_count":2, - "image_input":"data:image/png;base64,abc", - "api_key_id":100 - }`), &req) - require.NoError(t, err) - require.Equal(t, "sora2-landscape-10s", req.Model) - require.Equal(t, "test prompt", req.Prompt) - require.Equal(t, "video", req.MediaType) - require.Equal(t, 2, req.VideoCount) - require.Equal(t, "data:image/png;base64,abc", req.ImageInput) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(100), *req.APIKeyID) -} - -func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { - // api_key_id 为 nil 时 JSON 序列化应省略 - req := GenerateRequest{Model: "sora2", Prompt: "test"} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - _, hasAPIKeyID := parsed["api_key_id"] - require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") -} - -func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { - // api_key_id 不为 nil 时 JSON 序列化应包含 - id := int64(42) - req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - require.Equal(t, float64(42), parsed["api_key_id"]) -} - -// ==================== GetQuota: 有配额服务 ==================== - -func TestGetQuota_WithQuotaService_Success(t *testing.T) { - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10 * 1024 * 1024, - SoraStorageUsedBytes: 3 * 1024 * 1024, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, "user", data["source"]) - require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) - require.Equal(t, float64(3*1024*1024), data["used_bytes"]) -} - -func TestGetQuota_WithQuotaService_Error(t *testing.T) { - // 用户不存在时 GetQuota 返回错误 - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) - h.GetQuota(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== Generate: 配额检查 ==================== - -func TestGenerate_QuotaCheckFailed(t *testing.T) { - // 配额超限时返回 429 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 1024, - SoraStorageUsedBytes: 1025, // 已超限 - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -func TestGenerate_QuotaCheckPassed(t *testing.T) { - // 配额充足时允许生成 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== - -var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) - -type stubSettingRepoForHandler struct { - values map[string]string -} - -func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { - if values == nil { - values = make(map[string]string) - } - return &stubSettingRepoForHandler{values: values} -} - -func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { - if v, ok := r.values[key]; ok { - return &service.Setting{Key: key, Value: v}, nil - } - return nil, service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { - if v, ok := r.values[key]; ok { - return v, nil - } - return "", service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { - r.values[key] = value - return nil -} -func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { - result := make(map[string]string) - for _, k := range keys { - if v, ok := r.values[k]; ok { - result[k] = v - } - } - return result, nil -} -func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { - for k, v := range settings { - r.values[k] = v - } - return nil -} -func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { - return r.values, nil -} -func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { - delete(r.values, key) - return nil -} - -// ==================== S3 / MediaStorage 辅助函数 ==================== - -// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 -func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { - settingRepo := newStubSettingRepoForHandler(map[string]string{ - "sora_s3_enabled": "true", - "sora_s3_endpoint": endpoint, - "sora_s3_region": "us-east-1", - "sora_s3_bucket": "test-bucket", - "sora_s3_access_key_id": "AKIATEST", - "sora_s3_secret_access_key": "test-secret", - "sora_s3_prefix": "sora", - "sora_s3_force_path_style": "true", - }) - settingService := service.NewSettingService(settingRepo, &config.Config{}) - return service.NewSoraS3Storage(settingService) -} - -// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 -func newFakeSourceServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "video/mp4") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("fake video data for test")) - })) -} - -// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 -// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 -func newFakeS3Server(mode string) *httptest.Server { - var counter atomic.Int32 - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.Copy(io.Discard, r.Body) - _ = r.Body.Close() - - switch mode { - case "ok": - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - case "fail": - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - case "fail-second": - n := counter.Add(1) - if n <= 1 { - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - } - } - })) -} - -// ==================== processGeneration 直接调用测试 ==================== - -func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - repo.updateErr = fmt.Errorf("db error") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" - // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed - // 因此 ErrorMessage 为空(证明未调用 MarkFailed) - require.Equal(t, "generating", repo.gens[1].Status) - require.Empty(t, repo.gens[1].ErrorMessage) -} - -func TestProcessGeneration_GatewayServiceNil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - // gatewayService 未设置 → MarkFailed - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -// ==================== storeMediaWithDegradation: S3 路径 ==================== - -func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 1) - require.NotEmpty(t, s3Keys[0]) - require.Len(t, storedURLs, 1) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURL, fakeS3.URL) - require.Contains(t, storedURL, "/test-bucket/") - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 2) - require.Len(t, storedURLs, 2) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURLs[0], fakeS3.URL) - require.Contains(t, storedURLs[1], fakeS3.URL) - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { - // 上游返回 404 → 下载失败 → S3 上传不会开始 - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) - defer badSource.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -// ==================== storeMediaWithDegradation: 本地存储路径 ==================== - -func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { - // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: "/dev/null/invalid_dir", - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - // 本地存储失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeLocal, storageType) - require.Nil(t, s3Keys) // 本地存储不返回 S3 keys -} - -func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{ - s3Storage: s3Storage, - mediaStorage: mediaStorage, - } - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败 → 本地存储成功 - require.Equal(t, service.SoraStorageTypeLocal, storageType) -} - -// ==================== SaveToStorage: S3 路径 ==================== - -func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "S3") -} - -func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { - expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer expiredServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: expiredServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusGone, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "过期") -} - -func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Contains(t, data["message"], "S3") - require.NotEmpty(t, data["object_key"]) - // 验证记录已更新为 S3 存储 - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) -} - -func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{ - sourceServer.URL + "/v1.mp4", - sourceServer.URL + "/v2.mp4", - }, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Len(t, data["object_keys"].([]any), 2) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.Len(t, repo.gens[1].S3ObjectKeys, 2) - require.Len(t, repo.gens[1].MediaURLs, 2) -} - -func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 100 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - // 验证配额已累加 - require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) -} - -func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== GetStorageStatus: S3 路径 ==================== - -func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { - // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) -} - -func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, true, data["s3_healthy"]) -} - -// ==================== Stub: AccountRepository (用于 GatewayService) ==================== - -var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) - -type stubAccountRepoForHandler struct { - accounts []service.Account -} - -func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { - for i := range r.accounts { - if r.accounts[i].ID == id { - return &r.accounts[i], nil - } - } - return nil, fmt.Errorf("account not found") -} -func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { - return false, nil -} -func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } -func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { - return nil -} -func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { - return 0, nil -} -func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } -func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { - return nil -} -func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { - return 0, nil -} - -func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { - return nil -} - -func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { - return nil -} - -// ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== - -var _ service.SoraClient = (*stubSoraClientForHandler)(nil) - -type stubSoraClientForHandler struct { - videoStatus *service.SoraVideoTaskStatus -} - -func (s *stubSoraClientForHandler) Enabled() bool { return true } -func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { - return s.videoStatus, nil -} - -// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== - -// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 -func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { - return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - ) -} - -// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 -func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - return service.NewSoraGatewayService(soraClient, nil, nil, cfg) -} - -// ==================== processGeneration: 更多路径测试 ==================== - -func TestProcessGeneration_SelectAccountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // 提供可用账号使 SelectAccountForModel 成功 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // soraGatewayService 为 nil - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") -} - -func TestProcessGeneration_ForwardError(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回视频任务失败 - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "content policy violation", - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") -} - -func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration - // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回 completed 但无 URL - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: nil, // 无 URL - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") -} - -func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) - // 第 2 次返回 "cancelled" 状态,模拟外部取消 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - // 无 S3 和本地存储,降级到 upstream - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].MediaURL) -} - -func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{sourceServer.URL + "/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - s3Storage := newS3StorageForHandler(fakeS3.URL) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - s3Storage: s3Storage, - quotaService: quotaService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].S3ObjectKeys) - require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) - // 验证配额已累加 - require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) -} - -func TestProcessGeneration_MarkCompletedFails(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 - repo.updateCallCount = new(int32) - repo.updateFailAfterN = 1 - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 - // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 - // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 - require.Equal(t, "completed", repo.gens[1].Status) -} - -// ==================== cleanupStoredMedia 直接测试 ==================== - -func TestCleanupStoredMedia_S3Path(t *testing.T) { - // S3 清理路径:s3Storage 为 nil 时不 panic - h := &SoraClientHandler{} - // 不应 panic - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalPath(t *testing.T) { - // 本地清理路径:mediaStorage 为 nil 时不 panic - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) -} - -func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { - // upstream 类型不清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) -} - -func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { - // 空 keys 不触发清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) -} - -// ==================== DeleteGeneration: 本地存储清理路径 ==================== - -func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: []string{"video/test.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { - // MediaURLs 为空,使用 MediaURL 作为清理路径 - tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: nil, // 空 - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { - // 非本地存储类型 → 跳过清理 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeUpstream, - MediaURL: "https://upstream.com/v.mp4", - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_DeleteError(t *testing.T) { - // repo.Delete 出错 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} - repo.deleteErr = fmt.Errorf("delete failed") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== fetchUpstreamModels 测试 ==================== - -func TestFetchUpstreamModels_NilGateway(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - h := &SoraClientHandler{} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "gatewayService 未初始化") -} - -func TestFetchUpstreamModels_NoAccounts(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "选择 Sora 账号失败") -} - -func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "不支持模型同步") -} - -func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"base_url": "https://sora.test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "api_key") -} - -func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" - // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) -} - -func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "状态码 500") -} - -func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("not json")) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "解析响应失败") -} - -func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "空模型列表") -} - -func TestFetchUpstreamModels_Success(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 验证请求头 - require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) - require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - families, err := h.fetchUpstreamModels(context.Background()) - require.NoError(t, err) - require.NotEmpty(t, families) -} - -func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "未能从上游模型列表中识别") -} - -// ==================== getModelFamilies 缓存测试 ==================== - -func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { - // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 - h := &SoraClientHandler{} - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - - // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) - require.False(t, h.modelCacheUpstream) -} - -func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { - t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - require.True(t, h.modelCacheUpstream) - - // 第二次调用命中缓存 - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) -} - -func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { - // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) - h := &SoraClientHandler{ - cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, - modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 - modelCacheUpstream: false, - } - // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - // 缓存已刷新,不再是 "old" - found := false - for _, f := range families { - if f.ID == "old" { - found = true - } - } - require.False(t, found, "过期缓存应被刷新") -} - -// ==================== processGeneration: groupID 与 ForcePlatform ==================== - -func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 空账号列表 → SelectAccountForModel 失败 - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== - -func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { - // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -// ==================== Generate: CreatePending 并发限制错误 ==================== - -// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 -type stubSoraGenRepoWithAtomicCreate struct { - stubSoraGenRepo - limitErr error -} - -func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { - if r.limitErr != nil { - return r.limitErr - } - return r.stubSoraGenRepo.Create(context.Background(), gen) -} - -func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { - repo := &stubSoraGenRepoWithAtomicCreate{ - stubSoraGenRepo: *newStubSoraGenRepo(), - limitErr: service.ErrSoraGenerationConcurrencyLimit, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "3") -} - -// ==================== SaveToStorage: 配额超限 ==================== - -func TestSaveToStorage_QuotaExceeded(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户配额已满 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10, - SoraStorageUsedBytes: 10, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== - -func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== SaveToStorage: MediaURLs 全为空 ==================== - -func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: "", - MediaURLs: []string{}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "已过期") -} - -// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== - -func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== - -func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 100 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== - -func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) -} - -func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) -} - -// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== - -func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-del-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "nonexistent/video.mp4", - MediaURLs: []string{"nonexistent/video.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== CancelGeneration: 任务已结束冲突 ==================== - -func TestCancelGeneration_AlreadyCompleted(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go deleted file mode 100644 index 5e505409..00000000 --- a/backend/internal/handler/sora_gateway_handler.go +++ /dev/null @@ -1,695 +0,0 @@ -package handler - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" - "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" - - "github.com/gin-gonic/gin" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.uber.org/zap" -) - -// SoraGatewayHandler handles Sora chat completions requests -type SoraGatewayHandler struct { - gatewayService *service.GatewayService - soraGatewayService *service.SoraGatewayService - billingCacheService *service.BillingCacheService - usageRecordWorkerPool *service.UsageRecordWorkerPool - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int - streamMode string - soraTLSEnabled bool - soraMediaSigningKey string - soraMediaRoot string -} - -// NewSoraGatewayHandler creates a new SoraGatewayHandler -func NewSoraGatewayHandler( - gatewayService *service.GatewayService, - soraGatewayService *service.SoraGatewayService, - concurrencyService *service.ConcurrencyService, - billingCacheService *service.BillingCacheService, - usageRecordWorkerPool *service.UsageRecordWorkerPool, - cfg *config.Config, -) *SoraGatewayHandler { - pingInterval := time.Duration(0) - maxAccountSwitches := 3 - streamMode := "force" - soraTLSEnabled := true - signKey := "" - mediaRoot := "/app/data/sora" - if cfg != nil { - pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second - if cfg.Gateway.MaxAccountSwitches > 0 { - maxAccountSwitches = cfg.Gateway.MaxAccountSwitches - } - if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { - streamMode = mode - } - soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint - signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) - if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { - mediaRoot = root - } - } - return &SoraGatewayHandler{ - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - billingCacheService: billingCacheService, - usageRecordWorkerPool: usageRecordWorkerPool, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, - streamMode: strings.ToLower(streamMode), - soraTLSEnabled: soraTLSEnabled, - soraMediaSigningKey: signKey, - soraMediaRoot: mediaRoot, - } -} - -// ChatCompletions handles Sora /v1/chat/completions endpoint -func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { - apiKey, ok := middleware2.GetAPIKeyFromContext(c) - if !ok { - h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") - return - } - - subject, ok := middleware2.GetAuthSubjectFromContext(c) - if !ok { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") - return - } - reqLog := requestLogger( - c, - "handler.sora_gateway.chat_completions", - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - ) - - body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) - if err != nil { - if maxErr, ok := extractMaxBytesError(err); ok { - h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) - return - } - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") - return - } - if len(body) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return - } - - setOpsRequestContext(c, "", false, body) - - // 校验请求体 JSON 合法性 - if !gjson.ValidBytes(body) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } - - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - modelResult := gjson.GetBytes(body, "model") - if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") - return - } - reqModel := modelResult.String() - - msgsResult := gjson.GetBytes(body, "messages") - if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") - return - } - - clientStream := gjson.GetBytes(body, "stream").Bool() - reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) - if !clientStream { - if h.streamMode == "error" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") - return - } - var err error - body, err = sjson.SetBytes(body, "stream", true) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } - - setOpsRequestContext(c, reqModel, clientStream, body) - setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false))) - - platform := "" - if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { - platform = forced - } else if apiKey.Group != nil { - platform = apiKey.Group.Platform - } - if platform != service.PlatformSora { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") - return - } - - streamStarted := false - subscription, _ := middleware2.GetSubscriptionFromContext(c) - - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) - } else if !canWait { - reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() - - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) - if err != nil { - reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) - return - } - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) - if userReleaseFunc != nil { - defer userReleaseFunc() - } - - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) - h.handleStreamingAwareError(c, status, code, message, streamStarted) - return - } - - sessionHash := generateOpenAISessionHash(c, body) - - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 - var lastFailoverBody []byte - var lastFailoverHeaders http.Header - - for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") - if err != nil { - reqLog.Warn("sora.account_select_failed", - zap.Error(err), - zap.Int("excluded_account_count", len(failedAccountIDs)), - ) - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int("last_upstream_status", lastFailoverStatus), - } - if rayID != "" { - fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("last_upstream_content_type", contentType)) - } - reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID, account.Platform) - proxyBound := account.ProxyID != nil - proxyID := int64(0) - if account.ProxyID != nil { - proxyID = *account.ProxyID - } - tlsFingerprintEnabled := h.soraTLSEnabled - - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - reqLog.Warn("sora.account_wait_counter_increment_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - } else if !canWait { - reqLog.Info("sora.account_wait_queue_full", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), - ) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - clientStream, - &streamStarted, - ) - if err != nil { - reqLog.Warn("sora.account_slot_acquire_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_exhausted", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - switchCount++ - upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.String("upstream_error_code", upstreamErrCode), - zap.String("upstream_error_message", upstreamErrMsg), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_switching", fields...) - continue - } - reqLog.Error("sora.forward_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - return - } - - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - requestPayloadHash := service.HashUsageRequestPayload(body) - inboundEndpoint := GetInboundEndpoint(c) - upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - - // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitUsageRecordTask(func(ctx context.Context) { - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: inboundEndpoint, - UpstreamEndpoint: upstreamEndpoint, - UserAgent: userAgent, - IPAddress: clientIP, - RequestPayloadHash: requestPayloadHash, - }); err != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - zap.String("model", reqModel), - zap.Int64("account_id", account.ID), - ).Error("sora.record_usage_failed", zap.Error(err)) - } - }) - reqLog.Debug("sora.request_completed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("switch_count", switchCount), - ) - return - } -} - -func generateOpenAISessionHash(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } - if sessionID == "" { - return "" - } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) -} - -func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Any("panic", recovered), - ).Error("sora.usage_record_task_panic_recovered") - } - }() - task(ctx) -} - -func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", - fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) -} - -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { - upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) - service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") - - status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) - h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) -} - -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { - if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { - baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) - return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - - upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) - if strings.EqualFold(upstreamCode, "cf_shield_429") { - baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." - return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { - switch statusCode { - case 401, 403, 404, 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", upstreamMessage - case 429: - return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage - } - } - - switch statusCode { - case 401: - return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" - case 403: - return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" - case 404: - if strings.EqualFold(upstreamCode, "unsupported_country_code") { - return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" - } - return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" - case 429: - return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" - case 529: - return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" - case 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" - default: - return http.StatusBadGateway, "upstream_error", "Upstream request failed" - } -} - -func cloneHTTPHeaders(headers http.Header) http.Header { - if headers == nil { - return nil - } - return headers.Clone() -} - -func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { - if headers != nil { - mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) - contentType = strings.TrimSpace(headers.Get("content-type")) - if contentType == "" { - contentType = strings.TrimSpace(headers.Get("Content-Type")) - } - } - rayID = soraerror.ExtractCloudflareRayID(headers, body) - return rayID, mitigated, contentType -} - -func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) -} - -func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { - message = strings.TrimSpace(message) - if message == "" { - return false - } - if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { - lower := strings.ToLower(message) - if strings.Contains(lower, "Just a moment...`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "upstream_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare challenge") - require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") -} - -func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/", nil) - - headers := http.Header{} - headers.Set("cf-ray", "9d03b68c086027a1-SEA") - body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "rate_limit_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare shield") - require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") -} - -func TestExtractSoraFailoverHeaderInsights(t *testing.T) { - headers := http.Header{} - headers.Set("cf-mitigated", "challenge") - headers.Set("content-type", "text/html") - body := []byte(``) - - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) - require.Equal(t, "9cff2d62d83bb98d", rayID) - require.Equal(t, "challenge", mitigated) - require.Equal(t, "text/html", contentType) -} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index c7c48e14..5c945815 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") } - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { - pool := newUsageRecordTestPool(t) - h := &SoraGatewayHandler{usageRecordWorkerPool: pool} - - done := make(chan struct{}) - h.submitUsageRecordTask(func(ctx context.Context) { - close(done) - }) - - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("task not executed") - } -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { - h := &SoraGatewayHandler{} - var called atomic.Bool - - h.submitUsageRecordTask(func(ctx context.Context) { - if _, ok := ctx.Deadline(); !ok { - t.Fatal("expected deadline in fallback context") - } - called.Store(true) - }) - - require.True(t, called.Load()) -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { - h := &SoraGatewayHandler{} - require.NotPanics(t, func() { - h.submitUsageRecordTask(nil) - }) -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { - h := &SoraGatewayHandler{} - var called atomic.Bool - - require.NotPanics(t, func() { - h.submitUsageRecordTask(func(ctx context.Context) { - panic("usage task panic") - }) - }) - - h.submitUsageRecordTask(func(ctx context.Context) { - called.Store(true) - }) - require.True(t, called.Load(), "panic 后后续任务应仍可执行") -} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02ddd030..d9622594 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -33,6 +33,7 @@ func ProvideAdminHandlers( tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, + channelHandler *admin.ChannelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -59,6 +60,7 @@ func ProvideAdminHandlers( TLSFingerprintProfile: tlsFingerprintProfileHandler, APIKey: apiKeyHandler, ScheduledTest: scheduledTestHandler, + Channel: channelHandler, } } @@ -84,8 +86,6 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, - soraGatewayHandler *SoraGatewayHandler, - soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, _ *service.IdempotencyCoordinator, @@ -102,8 +102,6 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, - SoraGateway: soraGatewayHandler, - SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -121,7 +119,6 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, - NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, @@ -150,6 +147,7 @@ var ProviderSet = wire.NewSet( admin.NewTLSFingerprintProfileHandler, admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, + admin.NewChannelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 8ea87f18..ce144bb9 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -125,6 +125,7 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // ClaudeError Claude 错误响应 diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 1a0ca5bb..033dccbd 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -149,13 +149,31 @@ type GeminiCandidate struct { GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` } +// GeminiTokenDetail Gemini token 详情(按模态分类) +type GeminiTokenDetail struct { + Modality string `json:"modality"` + TokenCount int `json:"tokenCount"` +} + // GeminiUsageMetadata Gemini 用量元数据 type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount,omitempty"` - CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` - CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` - TotalTokenCount int `json:"totalTokenCount,omitempty"` - ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) + PromptTokenCount int `json:"promptTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + TotalTokenCount int `json:"totalTokenCount,omitempty"` + ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) + CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"` + PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"` +} + +// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数 +func (m *GeminiUsageMetadata) ImageOutputTokens() int { + for _, d := range m.CandidatesTokensDetails { + if d.Modality == "IMAGE" { + return d.TokenCount + } + } + return 0 } // GeminiGroundingMetadata Gemini grounding 元数据(Web Search) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index f962253d..b64b633b 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -53,6 +53,7 @@ const ( // defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0 var defaultUserAgentVersion = "1.107.0" + // defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 var defaultClientSecret string diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 3a093fe6..9850af17 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) { if RedirectURI != "http://localhost:8085/callback" { t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) } - if GetUserAgent() != "antigravity/1.20.5 windows/amd64" { + if GetUserAgent() != "antigravity/1.21.9 windows/amd64" { t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) } if SessionTTL != 30*time.Minute { diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index f12effb6..bc1fd32e 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached + usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens() } // 生成响应 ID diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index deed5f92..58982878 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -32,9 +32,10 @@ type StreamingProcessor struct { groundingChunks []GeminiGroundingChunk // 累计 usage - inputTokens int - outputTokens int - cacheReadTokens int + inputTokens int + outputTokens int + cacheReadTokens int + imageOutputTokens int } // NewStreamingProcessor 创建流式响应处理器 @@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.cacheReadTokens = cached + p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens() } // 处理 parts @@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, + ImageOutputTokens: p.imageOutputTokens, } if !p.messageStartSent { @@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached + usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens() } responseID := v1Resp.ResponseID @@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, + ImageOutputTokens: p.imageOutputTokens, } deltaEvent := map[string]any{ diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index f54a4a02..903d5b31 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -876,3 +876,182 @@ func TestChatCompletionsStreamRoundTrip(t *testing.T) { assert.Equal(t, "resp_rt", c.ID) } } + +// --------------------------------------------------------------------------- +// BufferedResponseAccumulator tests +// --------------------------------------------------------------------------- + +func TestBufferedResponseAccumulator_TextOnly(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "Hello"}) + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: ", world!"}) + + assert.True(t, acc.HasContent()) + + output := acc.BuildOutput() + require.Len(t, output, 1) + assert.Equal(t, "message", output[0].Type) + assert.Equal(t, "assistant", output[0].Role) + require.Len(t, output[0].Content, 1) + assert.Equal(t, "output_text", output[0].Content[0].Type) + assert.Equal(t, "Hello, world!", output[0].Content[0].Text) +} + +func TestBufferedResponseAccumulator_ToolCalls(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + // Add function call at output_index=1 + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 1, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_abc", + Name: "get_weather", + }, + }) + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `{"city":`, + }) + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `"NYC"}`, + }) + + assert.True(t, acc.HasContent()) + + output := acc.BuildOutput() + require.Len(t, output, 1) + assert.Equal(t, "function_call", output[0].Type) + assert.Equal(t, "call_abc", output[0].CallID) + assert.Equal(t, "get_weather", output[0].Name) + assert.Equal(t, `{"city":"NYC"}`, output[0].Arguments) +} + +func TestBufferedResponseAccumulator_Reasoning(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "Step 1: "}) + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "think about it"}) + + assert.True(t, acc.HasContent()) + + output := acc.BuildOutput() + require.Len(t, output, 1) + assert.Equal(t, "reasoning", output[0].Type) + require.Len(t, output[0].Summary, 1) + assert.Equal(t, "summary_text", output[0].Summary[0].Type) + assert.Equal(t, "Step 1: think about it", output[0].Summary[0].Text) +} + +func TestBufferedResponseAccumulator_Mixed(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + // Reasoning first + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "I thought about it."}) + + // Then text + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "The answer is 42."}) + + // Then a tool call + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 2, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_1", + Name: "verify", + }, + }) + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 2, + Delta: `{}`, + }) + + assert.True(t, acc.HasContent()) + + output := acc.BuildOutput() + // Order: reasoning → message → function_calls + require.Len(t, output, 3) + assert.Equal(t, "reasoning", output[0].Type) + assert.Equal(t, "message", output[1].Type) + assert.Equal(t, "function_call", output[2].Type) + assert.Equal(t, "The answer is 42.", output[1].Content[0].Text) + assert.Equal(t, "verify", output[2].Name) +} + +func TestBufferedResponseAccumulator_SupplementEmptyOutput(t *testing.T) { + acc := NewBufferedResponseAccumulator() + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "Hello"}) + + resp := &ResponsesResponse{ + ID: "resp_1", + Status: "completed", + Output: nil, // empty output + Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5}, + } + + acc.SupplementResponseOutput(resp) + + require.Len(t, resp.Output, 1) + assert.Equal(t, "message", resp.Output[0].Type) + assert.Equal(t, "Hello", resp.Output[0].Content[0].Text) + // Usage should be untouched + assert.Equal(t, 10, resp.Usage.InputTokens) +} + +func TestBufferedResponseAccumulator_NoSupplementWhenOutputExists(t *testing.T) { + acc := NewBufferedResponseAccumulator() + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "from deltas"}) + + resp := &ResponsesResponse{ + ID: "resp_2", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "from terminal event"}, + }, + }, + }, + } + + acc.SupplementResponseOutput(resp) + + // Output should NOT be overwritten + require.Len(t, resp.Output, 1) + assert.Equal(t, "from terminal event", resp.Output[0].Content[0].Text) +} + +func TestBufferedResponseAccumulator_EmptyDeltas(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + // Process events with empty delta — should not accumulate + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: ""}) + acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.created"}) + + assert.False(t, acc.HasContent()) + + resp := &ResponsesResponse{ID: "resp_3", Status: "completed"} + acc.SupplementResponseOutput(resp) + assert.Nil(t, resp.Output) +} + +func TestBufferedResponseAccumulator_IgnoresNonFunctionCallItems(t *testing.T) { + acc := NewBufferedResponseAccumulator() + + // output_item.added with type "message" should be ignored + acc.ProcessEvent(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 0, + Item: &ResponsesOutput{Type: "message"}, + }) + + assert.False(t, acc.HasContent()) +} diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 688a68eb..61b3bf9c 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strings" "time" ) @@ -372,3 +373,119 @@ func generateChatCmplID() string { _, _ = rand.Read(b) return "chatcmpl-" + hex.EncodeToString(b) } + +// --------------------------------------------------------------------------- +// BufferedResponseAccumulator: accumulates SSE delta events for non-streaming +// paths where the terminal event may have empty output. +// --------------------------------------------------------------------------- + +type bufferedFuncCall struct { + CallID string + Name string + Args strings.Builder +} + +// BufferedResponseAccumulator collects content from Responses SSE delta events +// so that non-streaming handlers can reconstruct output when the terminal event +// (response.completed / response.done) carries an empty output array. +type BufferedResponseAccumulator struct { + text strings.Builder + reasoning strings.Builder + funcCalls []bufferedFuncCall + outputIndexToFuncIdx map[int]int +} + +// NewBufferedResponseAccumulator returns an initialised accumulator. +func NewBufferedResponseAccumulator() *BufferedResponseAccumulator { + return &BufferedResponseAccumulator{ + outputIndexToFuncIdx: make(map[int]int), + } +} + +// ProcessEvent inspects a single Responses SSE event and accumulates any +// content it carries. Only delta events that contribute to the final output +// are handled; all other event types are silently ignored. +func (a *BufferedResponseAccumulator) ProcessEvent(event *ResponsesStreamEvent) { + switch event.Type { + case "response.output_text.delta": + if event.Delta != "" { + _, _ = a.text.WriteString(event.Delta) + } + case "response.output_item.added": + if event.Item != nil && event.Item.Type == "function_call" { + idx := len(a.funcCalls) + a.outputIndexToFuncIdx[event.OutputIndex] = idx + a.funcCalls = append(a.funcCalls, bufferedFuncCall{ + CallID: event.Item.CallID, + Name: event.Item.Name, + }) + } + case "response.function_call_arguments.delta": + if event.Delta != "" { + if idx, ok := a.outputIndexToFuncIdx[event.OutputIndex]; ok { + _, _ = a.funcCalls[idx].Args.WriteString(event.Delta) + } + } + case "response.reasoning_summary_text.delta": + if event.Delta != "" { + _, _ = a.reasoning.WriteString(event.Delta) + } + } +} + +// HasContent reports whether any content has been accumulated. +func (a *BufferedResponseAccumulator) HasContent() bool { + return a.text.Len() > 0 || len(a.funcCalls) > 0 || a.reasoning.Len() > 0 +} + +// BuildOutput constructs a []ResponsesOutput from the accumulated delta +// content. The order matches what ResponsesToChatCompletions expects: +// reasoning → message → function_calls. +func (a *BufferedResponseAccumulator) BuildOutput() []ResponsesOutput { + var out []ResponsesOutput + + if a.reasoning.Len() > 0 { + out = append(out, ResponsesOutput{ + Type: "reasoning", + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: a.reasoning.String(), + }}, + }) + } + + if a.text.Len() > 0 { + out = append(out, ResponsesOutput{ + Type: "message", + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: a.text.String(), + }}, + }) + } + + for i := range a.funcCalls { + out = append(out, ResponsesOutput{ + Type: "function_call", + CallID: a.funcCalls[i].CallID, + Name: a.funcCalls[i].Name, + Arguments: a.funcCalls[i].Args.String(), + }) + } + + return out +} + +// SupplementResponseOutput fills resp.Output from accumulated delta content +// when the terminal event delivered an empty output array. If resp.Output is +// already populated, this is a no-op (preserves backward compatibility). +func (a *BufferedResponseAccumulator) SupplementResponseOutput(resp *ResponsesResponse) { + if resp == nil || len(resp.Output) > 0 { + return + } + if !a.HasContent() { + return + } + resp.Output = a.BuildOutput() +} diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 882d2ebd..fac79d18 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -2,6 +2,8 @@ // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). package gemini +import "strings" + type Model struct { Name string `json:"name"` DisplayName string `json:"displayName,omitempty"` @@ -23,10 +25,27 @@ func DefaultModels() []Model { {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } +func HasFallbackModel(model string) bool { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return false + } + if !strings.HasPrefix(trimmed, "models/") { + trimmed = "models/" + trimmed + } + for _, model := range DefaultModels() { + if model.Name == trimmed { + return true + } + } + return false +} + func FallbackModelsList() ModelsListResponse { return ModelsListResponse{Models: DefaultModels()} } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go index b80047fb..1d20c0e6 100644 --- a/backend/internal/pkg/gemini/models_test.go +++ b/backend/internal/pkg/gemini/models_test.go @@ -2,7 +2,7 @@ package gemini import "testing" -func TestDefaultModels_ContainsImageModels(t *testing.T) { +func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) { t.Parallel() models := DefaultModels() @@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { required := []string{ "models/gemini-2.5-flash-image", + "models/gemini-3.1-pro-preview-customtools", "models/gemini-3.1-flash-image", } @@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { } } } + +func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) { + t.Parallel() + + if !HasFallbackModel("gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected customtools model to exist in fallback catalog") + } + if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected prefixed customtools model to exist in fallback catalog") + } + if HasFallbackModel("gemini-unknown") { + t.Fatalf("did not expect unknown model to exist in fallback catalog") + } +} diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index 6b8521bd..618b6adb 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,8 +17,6 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - // OAuth Client ID for Sora mobile flow (aligned with sora2api) - SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" @@ -39,8 +37,6 @@ const ( const ( // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. OAuthPlatformOpenAI = "openai" - // OAuthPlatformSora uses Sora OAuth client. - OAuthPlatformSora = "sora" ) // OAuthSession stores OAuth flow state for OpenAI @@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor } // OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. -// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), -// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { - switch strings.ToLower(strings.TrimSpace(platform)) { - case OAuthPlatformSora: - return ClientID, false - default: - return ClientID, true - } + return ClientID, true } // TokenRequest represents the token exchange request body diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go index 2970addf..56b42fc9 100644 --- a/backend/internal/pkg/openai/oauth_test.go +++ b/backend/internal/pkg/openai/oauth_test.go @@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) } } - -// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, -// 但不启用 codex_cli_simplified_flow。 -func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { - authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) - parsed, err := url.Parse(authURL) - if err != nil { - t.Fatalf("Parse URL failed: %v", err) - } - q := parsed.Query() - if got := q.Get("client_id"); got != ClientID { - t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) - } - if got := q.Get("codex_cli_simplified_flow"); got != "" { - t.Fatalf("codex flow should be empty for sora, got=%q", got) - } - if got := q.Get("id_token_add_organizations"); got != "true" { - t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) - } -} diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 44cddb6a..5d1f7911 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -175,6 +175,13 @@ type UserBreakdownDimension struct { ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" + // Additional filter conditions + UserID int64 // filter by user_id (>0 to enable) + APIKeyID int64 // filter by api_key_id (>0 to enable) + AccountID int64 // filter by account_id (>0 to enable) + RequestType *int16 // filter by request_type (non-nil to enable) + Stream *bool // filter by stream flag (non-nil to enable) + BillingType *int8 // filter by billing_type (non-nil to enable) } // APIKeyUsageTrendPoint represents API key usage trend data point @@ -230,6 +237,7 @@ type UsageLogFilters struct { RequestType *int16 Stream *bool BillingType *int8 + BillingMode string StartTime *time.Time EndTime *time.Time // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a12..14498715 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -468,6 +468,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati } if status != "" { switch status { + case service.StatusActive: + q = q.Where( + dbaccount.StatusEQ(status), + dbaccount.Or( + dbaccount.RateLimitResetAtIsNil(), + dbaccount.RateLimitResetAtLTE(time.Now()), + ), + ) case "rate_limited": q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) case "temp_unschedulable": @@ -1692,20 +1700,13 @@ func itoa(v int) string { } // FindByExtraField 根据 extra 字段中的键值对查找账号。 -// 该方法限定 platform='sora',避免误查询其他平台的账号。 // 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 // -// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 -// // FindByExtraField finds accounts by key-value pairs in the extra field. -// Limited to platform='sora' to avoid querying accounts from other platforms. // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). -// -// Use case: Finding Sora accounts linked via linked_openai_account_id. func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where( - dbaccount.PlatformEQ("sora"), // 限定平台为 sora dbaccount.DeletedAtIsNil(), func(s *entsql.Selector) { path := sqljson.Path(key) diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 8da30c92..f3e3f745 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -255,6 +255,22 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Equal(service.StatusDisabled, accounts[0].Status) }, }, + { + name: "filter_by_status_active_excludes_rate_limited", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive}) + rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) + err := client.Account.UpdateOneID(rateLimited.ID). + SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + }, + status: service.StatusActive, + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("active-normal", accounts[0].Name) + }, + }, { name: "filter_by_search", setup: func(client *dbent.Client) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 667193a6..b3b12e81 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, - group.FieldSoraImagePrice360, - group.FieldSoraImagePrice540, - group.FieldSoraVideoPricePerRequest, - group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, @@ -608,22 +604,20 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, - SoraStorageUsedBytes: u.SoraStorageUsedBytes, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, - SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, @@ -662,6 +651,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { SupportedModelScopes: g.SupportedModelScopes, SortOrder: g.SortOrder, AllowMessagesDispatch: g.AllowMessagesDispatch, + RequireOAuthOnly: g.RequireOauthOnly, + RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go new file mode 100644 index 00000000..1e2c2e4c --- /dev/null +++ b/backend/internal/repository/channel_repo.go @@ -0,0 +1,461 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +type channelRepository struct { + db *sql.DB +} + +// NewChannelRepository 创建渠道数据访问实例 +func NewChannelRepository(db *sql.DB) service.ChannelRepository { + return &channelRepository{db: db} +} + +// runInTx 在事务中执行 fn,成功 commit,失败 rollback。 +func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } + err = tx.QueryRowContext(ctx, + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, created_at, updated_at`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, + ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("insert channel: %w", err) + } + + // 设置分组关联 + if len(channel.GroupIDs) > 0 { + if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil { + return err + } + } + + // 设置模型定价 + if len(channel.ModelPricing) > 0 { + if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { + ch := &service.Channel{} + var modelMappingJSON []byte + err := r.db.QueryRowContext(ctx, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at + FROM channels WHERE id = $1`, id, + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) + if err == sql.ErrNoRows { + return nil, service.ErrChannelNotFound + } + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + + groupIDs, err := r.GetGroupIDs(ctx, id) + if err != nil { + return nil, err + } + ch.GroupIDs = groupIDs + + pricing, err := r.ListModelPricing(ctx, id) + if err != nil { + return nil, err + } + ch.ModelPricing = pricing + + return ch, nil +} + +func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } + result, err := tx.ExecContext(ctx, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() + WHERE id = $7`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, + ) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("update channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + + // 更新分组关联 + if channel.GroupIDs != nil { + if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil { + return err + } + } + + // 更新模型定价 + if channel.ModelPricing != nil { + if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) Delete(ctx context.Context, id int64) error { + result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + return nil +} + +func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) { + where := []string{"1=1"} + args := []any{} + argIdx := 1 + + if status != "" { + where = append(where, fmt.Sprintf("c.status = $%d", argIdx)) + args = append(args, status) + argIdx++ + } + if search != "" { + where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx)) + args = append(args, "%"+escapeLike(search)+"%") + argIdx++ + } + + whereClause := strings.Join(where, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause) + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count channels: %w", err) + } + + pageSize := params.Limit() // 约束在 [1, 100] + page := params.Page + if page < 1 { + page = 1 + } + offset := (page - 1) * pageSize + + // 查询 channel 列表 + dataQuery := fmt.Sprintf( + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at + FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`, + whereClause, argIdx, argIdx+1, + ) + args = append(args, pageSize, offset) + + rows, err := r.db.QueryContext(ctx, dataQuery, args...) + if err != nil { + return nil, nil, fmt.Errorf("query channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, nil, fmt.Errorf("scan channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate channels: %w", err) + } + + // 批量加载分组 ID 和模型定价(避免 N+1) + if len(channelIDs) > 0 { + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + } + + pages := 0 + if total > 0 { + pages = int((total + int64(pageSize) - 1) / int64(pageSize)) + } + + paginationResult := &pagination.PaginationResult{ + Total: total, + Page: page, + PageSize: pageSize, + Pages: pages, + } + + return channels, paginationResult, nil +} + +func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, + ) + if err != nil { + return nil, fmt.Errorf("query all channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate channels: %w", err) + } + + if len(channelIDs) == 0 { + return channels, nil + } + + // 批量加载分组 ID + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, err + } + + // 批量加载模型定价 + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, err + } + + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + + return channels, nil +} + +// --- 批量加载辅助方法 --- + +// batchLoadGroupIDs 批量加载多个渠道的分组 ID +func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT channel_id, group_id FROM channel_groups + WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load group ids: %w", err) + } + defer func() { _ = rows.Close() }() + + groupMap := make(map[int64][]int64, len(channelIDs)) + for rows.Next() { + var channelID, groupID int64 + if err := rows.Scan(&channelID, &groupID); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + groupMap[channelID] = append(groupMap[channelID], groupID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return groupMap, nil +} + +func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name, + ).Scan(&exists) + return exists, err +} + +func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID, + ).Scan(&exists) + return exists, err +} + +// --- 分组关联 --- + +func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("get group ids: %w", err) + } + defer func() { _ = rows.Close() }() + + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return ids, nil +} + +func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + return setGroupIDsTx(ctx, r.db, channelID, groupIDs) +} + +func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + var channelID int64 + err := r.db.QueryRowContext(ctx, + `SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID, + ).Scan(&channelID) + if err == sql.ErrNoRows { + return 0, nil + } + return channelID, err +} + +func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if len(groupIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`, + pq.Array(groupIDs), channelID, + ) + if err != nil { + return nil, fmt.Errorf("get groups in other channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var conflicting []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan conflicting group id: %w", err) + } + conflicting = append(conflicting, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate conflicting group ids: %w", err) + } + return conflicting, nil +} + +// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节 +// 格式:{"platform": {"src": "dst"}, ...} +func marshalModelMapping(m map[string]map[string]string) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal model_mapping: %w", err) + } + return data, nil +} + +// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping +func unmarshalModelMapping(data []byte) map[string]map[string]string { + if len(data) == 0 { + return nil + } + var m map[string]map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + +// GetGroupPlatforms 批量查询分组 ID 对应的平台 +func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + if len(groupIDs) == 0 { + return make(map[int64]string), nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, platform FROM groups WHERE id = ANY($1)`, + pq.Array(groupIDs), + ) + if err != nil { + return nil, fmt.Errorf("get group platforms: %w", err) + } + defer rows.Close() //nolint:errcheck + + result := make(map[int64]string, len(groupIDs)) + for rows.Next() { + var id int64 + var platform string + if err := rows.Scan(&id, &platform); err != nil { + return nil, fmt.Errorf("scan group platform: %w", err) + } + result[id] = platform + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group platforms: %w", err) + } + return result, nil +} diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go new file mode 100644 index 00000000..6dcf3c91 --- /dev/null +++ b/backend/internal/repository/channel_repo_pricing.go @@ -0,0 +1,291 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// --- 模型定价 --- + +func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("list model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + result, pricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + if len(pricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs) + if err != nil { + return nil, err + } + for i := range result { + result[i].Intervals = intervalMap[result[i].ID] + } + } + + return result, nil +} + +func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + return createModelPricingExec(ctx, r.db, pricing) +} + +func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + result, err := r.db.ExecContext(ctx, + `UPDATE channel_model_pricing + SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW() + WHERE id = $10`, + modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID, + ) + if err != nil { + return fmt.Errorf("update model pricing: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return fmt.Errorf("pricing entry not found: %d", pricing.ID) + } + return nil +} + +func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + _, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete model pricing: %w", err) + } + return nil +} + +func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + return replaceModelPricingTx(ctx, tx, channelID, pricingList) + }) +} + +// --- 批量加载辅助方法 --- + +// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) +func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + allPricing, allPricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + // 按 channelID 分组 + pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs)) + for _, p := range allPricing { + pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p) + } + + // 批量加载所有区间 + if len(allPricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for chID := range pricingMap { + for i := range pricingMap[chID] { + pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID] + } + } + } + + return pricingMap, nil +} + +// batchLoadIntervals 批量加载多个定价条目的区间 +func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs)) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan interval: %w", err) + } + intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate intervals: %w", err) + } + return intervalMap, nil +} + +// --- 共享 scan 辅助 --- + +// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表 +func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) { + var result []service.ChannelModelPricing + var pricingIDs []int64 + for rows.Next() { + var p service.ChannelModelPricing + var modelsJSON []byte + if err := rows.Scan( + &p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode, + &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, + &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan model pricing: %w", err) + } + if err := json.Unmarshal(modelsJSON, &p.Models); err != nil { + p.Models = []string{} + } + pricingIDs = append(pricingIDs, p.ID) + result = append(result, p) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate model pricing: %w", err) + } + return result, pricingIDs, nil +} + +// --- 事务内辅助方法 --- + +// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口 +type dbExec interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old group associations: %w", err) + } + if len(groupIDs) == 0 { + return nil + } + _, err := exec.ExecContext(ctx, + `INSERT INTO channel_groups (channel_id, group_id) + SELECT $1, unnest($2::bigint[])`, + channelID, pq.Array(groupIDs), + ) + if err != nil { + return fmt.Errorf("insert group associations: %w", err) + } + return nil +} + +func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := pricing.Platform + if platform == "" { + platform = "anthropic" + } + err = exec.QueryRowContext(ctx, + `INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + pricing.ChannelID, platform, modelsJSON, billingMode, + pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, + ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) + if err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + + for i := range pricing.Intervals { + pricing.Intervals[i].PricingID = pricing.ID + if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil { + return err + } + } + + return nil +} + +func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error { + return exec.QueryRowContext(ctx, + `INSERT INTO channel_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old model pricing: %w", err) + } + for i := range pricingList { + pricingList[i].ChannelID = channelID + if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + } + return nil +} + +// isUniqueViolation 检查 pq 唯一约束违反错误 +func isUniqueViolation(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr != nil { + return pqErr.Code == "23505" + } + return false +} + +// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符 +func escapeLike(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `%`, `\%`) + s = strings.ReplaceAll(s, `_`, `\_`) + return s +} diff --git a/backend/internal/repository/channel_repo_test.go b/backend/internal/repository/channel_repo_test.go new file mode 100644 index 00000000..5a59948d --- /dev/null +++ b/backend/internal/repository/channel_repo_test.go @@ -0,0 +1,227 @@ +//go:build unit + +package repository + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/lib/pq" + "github.com/stretchr/testify/require" +) + +// --- marshalModelMapping --- + +func TestMarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input map[string]map[string]string + wantJSON string // expected JSON output (exact match) + }{ + { + name: "empty map", + input: map[string]map[string]string{}, + wantJSON: "{}", + }, + { + name: "nil map", + input: nil, + wantJSON: "{}", + }, + { + name: "populated map", + input: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + }, + { + name: "nested values", + input: map[string]map[string]string{ + "openai": {"*": "gpt-5.4"}, + "anthropic": {"claude-old": "claude-new"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := marshalModelMapping(tt.input) + require.NoError(t, err) + + if tt.wantJSON != "" { + require.Equal(t, []byte(tt.wantJSON), result) + } else { + // round-trip: unmarshal and compare with input + var parsed map[string]map[string]string + require.NoError(t, json.Unmarshal(result, &parsed)) + require.Equal(t, tt.input, parsed) + } + }) + } +} + +// --- unmarshalModelMapping --- + +func TestUnmarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input []byte + wantNil bool + want map[string]map[string]string + }{ + { + name: "nil data", + input: nil, + wantNil: true, + }, + { + name: "empty data", + input: []byte{}, + wantNil: true, + }, + { + name: "invalid JSON", + input: []byte("not-json"), + wantNil: true, + }, + { + name: "type error - number", + input: []byte("42"), + wantNil: true, + }, + { + name: "type error - array", + input: []byte("[1,2,3]"), + wantNil: true, + }, + { + name: "valid JSON", + input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`), + want: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + "anthropic": {"old": "new"}, + }, + }, + { + name: "empty object", + input: []byte("{}"), + want: map[string]map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := unmarshalModelMapping(tt.input) + if tt.wantNil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + require.Equal(t, tt.want, result) + } + }) + } +} + +// --- escapeLike --- + +func TestEscapeLike(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no special chars", + input: "hello", + want: "hello", + }, + { + name: "backslash", + input: `a\b`, + want: `a\\b`, + }, + { + name: "percent", + input: "50%", + want: `50\%`, + }, + { + name: "underscore", + input: "a_b", + want: `a\_b`, + }, + { + name: "all special chars", + input: `a\b%c_d`, + want: `a\\b\%c\_d`, + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "consecutive special chars", + input: "%_%", + want: `\%\_\%`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, escapeLike(tt.input)) + }) + } +} + +// --- isUniqueViolation --- + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "unique violation code 23505", + err: &pq.Error{Code: "23505"}, + want: true, + }, + { + name: "different pq error code", + err: &pq.Error{Code: "23503"}, + want: false, + }, + { + name: "non-pq error", + err: errors.New("some generic error"), + want: false, + }, + { + name: "typed nil pq.Error", + err: func() error { + var pqErr *pq.Error + return pqErr + }(), + want: false, + }, + { + name: "bare nil", + err: nil, + want: false, + }, + { + name: "wrapped pq error with 23505", + err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isUniqueViolation(tt.err)) + }) + } +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 674c655b..a075b586 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -49,18 +49,15 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). - SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). - SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). - SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel) // 设置模型路由配置 @@ -120,16 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). - SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). - SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). - SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). + SetRequireOauthOnly(groupIn.RequireOAuthOnly). + SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 44fa291b..c1901d71 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) } -// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 -func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { - var seenClientIDs []string - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - clientID := r.PostForm.Get("client_id") - seenClientIDs = append(seenClientIDs, clientID) - if clientID == openai.SoraClientID { - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - - resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) - require.NoError(s.T(), err, "RefreshTokenWithClientID") - require.Equal(s.T(), "at-sora", resp.AccessToken) - require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) -} - func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { const customClientID = "custom-client-id" var seenClientIDs []string @@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { } func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { - wantClientID := openai.SoraClientID + wantClientID := "custom-exchange-client-id" errCh := make(chan string, 1) s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _ = r.ParseForm() diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go deleted file mode 100644 index ad2ae638..00000000 --- a/backend/internal/repository/sora_account_repo.go +++ /dev/null @@ -1,98 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "errors" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraAccountRepository 实现 service.SoraAccountRepository 接口。 -// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 -// -// 设计说明: -// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 -// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 -// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 -type soraAccountRepository struct { - sql *sql.DB -} - -// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 -func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { - return &soraAccountRepository{sql: sqlDB} -} - -// Upsert 创建或更新 Sora 账号扩展信息 -// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert -func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { - accessToken, accessOK := updates["access_token"].(string) - refreshToken, refreshOK := updates["refresh_token"].(string) - sessionToken, sessionOK := updates["session_token"].(string) - - if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { - if !sessionOK { - return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") - } - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_accounts - SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, - updated_at = NOW() - WHERE account_id = $1 - `, accountID, sessionToken) - if err != nil { - return err - } - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") - } - return nil - } - - _, err := r.sql.ExecContext(ctx, ` - INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) - VALUES ($1, $2, $3, $4, NOW(), NOW()) - ON CONFLICT (account_id) DO UPDATE SET - access_token = EXCLUDED.access_token, - refresh_token = EXCLUDED.refresh_token, - session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, - updated_at = NOW() - `, accountID, accessToken, refreshToken, sessionToken) - return err -} - -// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 -func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { - rows, err := r.sql.QueryContext(ctx, ` - SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') - FROM sora_accounts - WHERE account_id = $1 - `, accountID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return nil, nil // 记录不存在 - } - - var sa service.SoraAccount - if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { - return nil, err - } - return &sa, nil -} - -// Delete 删除 Sora 账号扩展信息 -func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { - _, err := r.sql.ExecContext(ctx, ` - DELETE FROM sora_accounts WHERE account_id = $1 - `, accountID) - return err -} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go deleted file mode 100644 index aaf3cb2f..00000000 --- a/backend/internal/repository/sora_generation_repo.go +++ /dev/null @@ -1,419 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 -// 使用原生 SQL 操作 sora_generations 表。 -type soraGenerationRepository struct { - sql *sql.DB -} - -// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 -func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { - return &soraGenerationRepository{sql: sqlDB} -} - -func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - err := r.sql.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt) - return err -} - -// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 -func (r *soraGenerationRepository) CreatePendingWithLimit( - ctx context.Context, - gen *service.SoraGeneration, - activeStatuses []string, - maxActive int64, -) error { - if gen == nil { - return fmt.Errorf("generation is nil") - } - if maxActive <= 0 { - return r.Create(ctx, gen) - } - if len(activeStatuses) == 0 { - activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} - } - - tx, err := r.sql.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - - // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 - if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { - return err - } - - placeholders := make([]string, len(activeStatuses)) - args := make([]any, 0, 1+len(activeStatuses)) - args = append(args, gen.UserID) - for i, s := range activeStatuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - countQuery := fmt.Sprintf( - `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, - strings.Join(placeholders, ","), - ) - var activeCount int64 - if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { - return err - } - if activeCount >= maxActive { - return service.ErrSoraGenerationConcurrencyLimit - } - - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - if err := tx.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt); err != nil { - return err - } - - return tx.Commit() -} - -func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - err := r.sql.QueryRowContext(ctx, ` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations WHERE id = $1 - `, id).Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("生成记录不存在") - } - return nil, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - return gen, nil -} - -func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - var completedAt *time.Time - if gen.CompletedAt != nil { - completedAt = gen.CompletedAt - } - - _, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations SET - status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5, - storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, - error_message = $9, completed_at = $10 - WHERE id = $1 - `, - gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, - gen.ErrorMessage, completedAt, - ) - return err -} - -// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 -func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, upstream_task_id = $3 - WHERE id = $1 AND status = $4 - `, - id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 -func (r *soraGenerationRepository) UpdateCompletedIfActive( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, - completedAt time.Time, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - media_url = $3, - media_urls = $4, - file_size_bytes = $5, - storage_type = $6, - s3_object_keys = $7, - error_message = '', - completed_at = $8 - WHERE id = $1 AND status IN ($9, $10) - `, - id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, - storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 -func (r *soraGenerationRepository) UpdateFailedIfActive( - ctx context.Context, - id int64, - errMsg string, - completedAt time.Time, -) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - error_message = $3, - completed_at = $4 - WHERE id = $1 AND status IN ($5, $6) - `, - id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 -func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, completed_at = $3 - WHERE id = $1 AND status IN ($4, $5) - `, - id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 -func (r *soraGenerationRepository) UpdateStorageIfCompleted( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET media_url = $2, - media_urls = $3, - file_size_bytes = $4, - storage_type = $5, - s3_object_keys = $6 - WHERE id = $1 AND status = $7 - `, - id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { - _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) - return err -} - -func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - // 构建 WHERE 条件 - conditions := []string{"user_id = $1"} - args := []any{params.UserID} - argIdx := 2 - - if params.Status != "" { - // 支持逗号分隔的多状态 - statuses := strings.Split(params.Status, ",") - placeholders := make([]string, len(statuses)) - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) - } - if params.StorageType != "" { - storageTypes := strings.Split(params.StorageType, ",") - placeholders := make([]string, len(storageTypes)) - for i, s := range storageTypes { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) - } - if params.MediaType != "" { - conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) - args = append(args, params.MediaType) - argIdx++ - } - - whereClause := "WHERE " + strings.Join(conditions, " AND ") - - // 计数 - var total int64 - countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) - if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { - return nil, 0, err - } - - // 分页查询 - offset := (params.Page - 1) * params.PageSize - listQuery := fmt.Sprintf(` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations %s - ORDER BY created_at DESC - LIMIT $%d OFFSET $%d - `, whereClause, argIdx, argIdx+1) - args = append(args, params.PageSize, offset) - - rows, err := r.sql.QueryContext(ctx, listQuery, args...) - if err != nil { - return nil, 0, err - } - defer func() { - _ = rows.Close() - }() - - var results []*service.SoraGeneration - for rows.Next() { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - if err := rows.Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ); err != nil { - return nil, 0, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - results = append(results, gen) - } - - return results, total, rows.Err() -} - -func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { - if len(statuses) == 0 { - return 0, nil - } - - placeholders := make([]string, len(statuses)) - args := []any{userID} - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - - var count int64 - query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) - err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) - return count, err -} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index e4da825b..d7bcd094 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{ "integer", // cache_read_tokens "integer", // cache_creation_5m_tokens "integer", // cache_creation_1h_tokens + "integer", // image_output_tokens + "numeric", // image_output_cost "numeric", // input_cost "numeric", // output_cost "numeric", // cache_creation_cost @@ -71,12 +73,15 @@ var usageLogInsertArgTypes = [...]string{ "text", // ip_address "integer", // image_count "text", // image_size - "text", // media_type "text", // service_tier "text", // reasoning_effort "text", // inbound_endpoint "text", // upstream_endpoint "boolean", // cache_ttl_overridden + "bigint", // channel_id + "text", // model_mapping_chain + "text", // billing_tier + "text", // billing_mode "timestamptz", // created_at } @@ -326,6 +331,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -344,20 +351,23 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, - $14, $15, - $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -758,6 +768,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -776,16 +788,19 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*39) + args := make([]any, 0, len(keys)*46) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -829,6 +844,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -847,12 +864,15 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) SELECT @@ -871,6 +891,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -889,12 +911,15 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -953,6 +978,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -971,16 +998,19 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*40) + args := make([]any, 0, len(preparedList)*45) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1021,6 +1051,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1039,12 +1071,15 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) SELECT @@ -1063,6 +1098,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1081,12 +1118,15 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1113,6 +1153,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1131,20 +1173,23 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, - $14, $15, - $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1171,11 +1216,14 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) - mediaType := nullString(log.MediaType) serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + channelID := nullInt64(log.ChannelID) + modelMappingChain := nullString(log.ModelMappingChain) + billingTier := nullString(log.BillingTier) + billingMode := nullString(log.BillingMode) requestedModel := strings.TrimSpace(log.RequestedModel) if requestedModel == "" { requestedModel = strings.TrimSpace(log.Model) @@ -1208,6 +1256,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -1226,12 +1276,15 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ipAddress, log.ImageCount, imageSize, - mediaType, serviceTier, reasoningEffort, inboundEndpoint, upstreamEndpoint, log.CacheTTLOverridden, + channelID, + modelMappingChain, + billingTier, + billingMode, createdAt, }, } @@ -2564,8 +2617,8 @@ type UsageLogFilters = usagestats.UsageLogFilters // ListWithFilters lists usage logs with optional filters (for admin) func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { - conditions := make([]string, 0, 8) - args := make([]any, 0, 8) + conditions := make([]string, 0, 9) + args := make([]any, 0, 9) if filters.UserID > 0 { conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) @@ -2589,6 +2642,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } + if filters.BillingMode != "" { + conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) + args = append(args, filters.BillingMode) + } if filters.StartTime != nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -3096,6 +3153,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) args = append(args, dim.Endpoint) } + if dim.UserID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, dim.UserID) + } + if dim.APIKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, dim.APIKeyID) + } + if dim.AccountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, dim.AccountID) + } + if dim.RequestType != nil { + query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1) + args = append(args, *dim.RequestType) + } + if dim.Stream != nil { + query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1) + args = append(args, *dim.Stream) + } + if dim.BillingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, *dim.BillingType) + } query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" if limit > 0 { @@ -3256,6 +3337,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } + if filters.BillingMode != "" { + conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) + args = append(args, filters.BillingMode) + } if filters.StartTime != nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -3935,6 +4020,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e cacheReadTokens int cacheCreation5m int cacheCreation1h int + imageOutputTokens int + imageOutputCost float64 inputCost float64 outputCost float64 cacheCreationCost float64 @@ -3953,12 +4040,15 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString - mediaType sql.NullString serviceTier sql.NullString reasoningEffort sql.NullString inboundEndpoint sql.NullString upstreamEndpoint sql.NullString cacheTTLOverridden bool + channelID sql.NullInt64 + modelMappingChain sql.NullString + billingTier sql.NullString + billingMode sql.NullString createdAt time.Time ) @@ -3979,6 +4069,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &cacheReadTokens, &cacheCreation5m, &cacheCreation1h, + &imageOutputTokens, + &imageOutputCost, &inputCost, &outputCost, &cacheCreationCost, @@ -3997,12 +4089,15 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, - &mediaType, &serviceTier, &reasoningEffort, &inboundEndpoint, &upstreamEndpoint, &cacheTTLOverridden, + &channelID, + &modelMappingChain, + &billingTier, + &billingMode, &createdAt, ); err != nil { return nil, err @@ -4021,6 +4116,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e CacheReadTokens: cacheReadTokens, CacheCreation5mTokens: cacheCreation5m, CacheCreation1hTokens: cacheCreation1h, + ImageOutputTokens: imageOutputTokens, + ImageOutputCost: imageOutputCost, InputCost: inputCost, OutputCost: outputCost, CacheCreationCost: cacheCreationCost, @@ -4069,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } - if mediaType.Valid { - log.MediaType = &mediaType.String - } if serviceTier.Valid { log.ServiceTier = &serviceTier.String } @@ -4087,6 +4181,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamModel.Valid { log.UpstreamModel = &upstreamModel.String } + if channelID.Valid { + value := channelID.Int64 + log.ChannelID = &value + } + if modelMappingChain.Valid { + log.ModelMappingChain = &modelMappingChain.String + } + if billingTier.Valid { + log.BillingTier = &billingTier.String + } + if billingMode.Valid { + log.BillingMode = &billingMode.String + } return log, nil } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index ebc8929a..ce0c5f00 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -74,12 +76,15 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // ip_address log.ImageCount, sqlmock.AnyArg(), // image_size - sqlmock.AnyArg(), // media_type sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // upstream_endpoint log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) @@ -129,6 +134,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -147,12 +154,15 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), log.ImageCount, sqlmock.AnyArg(), - sqlmock.AnyArg(), serviceTier, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) @@ -439,6 +449,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 4, // cache_read_tokens 5, // cache_creation_5m_tokens 6, // cache_creation_1h_tokens + 0, // image_output_tokens + 0.0, // image_output_cost 0.1, // input_cost 0.2, // output_cost 0.3, // cache_creation_cost @@ -457,12 +469,15 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -487,6 +502,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, + 0, 0.0, // image_output_tokens, image_output_cost 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 1.0, sql.NullFloat64{}, @@ -500,12 +516,15 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "flex"}, sql.NullString{}, sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -530,6 +549,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, + 0, 0.0, // image_output_tokens, image_output_cost 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 1.0, sql.NullFloat64{}, @@ -543,12 +563,15 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 575754e0..06c79113 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). - SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) @@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } -// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。 -func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) { - if deltaBytes <= 0 { - user, err := r.GetByID(ctx, userID) - if err != nil { - return 0, err - } - return user.SoraStorageUsedBytes, nil - } - var newUsed int64 - err := scanSingleRow(ctx, r.sql, ` - UPDATE users - SET sora_storage_used_bytes = sora_storage_used_bytes + $2 - WHERE id = $1 - AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3) - RETURNING sora_storage_used_bytes - `, []any{userID, deltaBytes, effectiveQuota}, &newUsed) - if err == nil { - return newUsed, nil - } - if errors.Is(err, sql.ErrNoRows) { - // 区分用户不存在和配额冲突 - exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx) - if existsErr != nil { - return 0, existsErr - } - if !exists { - return 0, service.ErrUserNotFound - } - return 0, service.ErrSoraStorageQuotaExceeded - } - return 0, err -} - -// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。 -func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) { - if deltaBytes <= 0 { - user, err := r.GetByID(ctx, userID) - if err != nil { - return 0, err - } - return user.SoraStorageUsedBytes, nil - } - var newUsed int64 - err := scanSingleRow(ctx, r.sql, ` - UPDATE users - SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0) - WHERE id = $1 - RETURNING sora_storage_used_bytes - `, []any{userID, deltaBytes}, &newUsed) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return 0, service.ErrUserNotFound - } - return 0, err - } - return newUsed, nil -} - func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 49d47bf6..657e3ed6 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, - NewSoraAccountRepository, // Sora 账号扩展表仓储 NewScheduledTestPlanRepository, // 定时测试计划仓储 NewScheduledTestResultRepository, // 定时测试结果仓储 NewProxyRepository, @@ -74,6 +73,7 @@ var ProviderSet = wire.NewSet( NewUserGroupRateRepository, NewErrorPassthroughRepository, NewTLSFingerprintProfileRepository, + NewChannelRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index ac4e05de..d412ea34 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -204,16 +204,13 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "sora_image_price_360": null, - "sora_image_price_540": null, - "sora_storage_quota_bytes": 0, - "sora_video_price_per_request": null, - "sora_video_price_per_request_hd": null, "claude_code_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "allow_messages_dispatch": false, + "require_oauth_only": false, + "require_privacy_set": false, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -530,7 +527,6 @@ func TestAPIContracts(t *testing.T) { "fallback_model_openai": "gpt-4o", "enable_identity_patch": true, "identity_patch_prompt": "", - "sora_client_enabled": false, "invitation_code_enabled": false, "home_content": "", "hide_ccs_import_button": false, @@ -651,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index d9ec951e..73210bfc 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool { return strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/responses") } diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 99701531..d60a142c 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -109,7 +109,6 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) - routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e04dae85..b921da95 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,8 +34,6 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) - // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) - registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -87,6 +85,9 @@ func RegisterAdminRoutes( // 定时测试计划 registerScheduledTestRoutes(admin, h) + + // 渠道管理 + registerChannelRoutes(admin, h) } } @@ -318,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } -func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - sora := admin.Group("/sora") - { - sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) - sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) - sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) - sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) - sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) - sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) - sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) - } -} - func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { @@ -419,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Beta 策略配置 adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) - // Sora S3 存储配置 - adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) - adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) - adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) - adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) - adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) - adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) - adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) - adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) } } @@ -567,3 +546,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) } } + +func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + channels := admin.Group("/channels") + { + channels.GET("", h.Admin.Channel.List) + channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) + channels.GET("/:id", h.Admin.Channel.GetByID) + channels.POST("", h.Admin.Channel.Create) + channels.PUT("/:id", h.Admin.Channel.Update) + channels.DELETE("/:id", h.Admin.Channel.Delete) + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 072cfdee..cbf98293 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -23,11 +23,6 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) - soraMaxBodySize := cfg.Gateway.SoraMaxBodySize - if soraMaxBodySize <= 0 { - soraMaxBodySize = cfg.Gateway.MaxBodySize - } - soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) endpointNorm := handler.InboundEndpointMiddleware() @@ -163,28 +158,6 @@ func RegisterGatewayRoutes( antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } - // Sora 专用路由(强制使用 sora 平台) - soraV1 := r.Group("/sora/v1") - soraV1.Use(soraBodyLimit) - soraV1.Use(clientRequestID) - soraV1.Use(opsErrorLogger) - soraV1.Use(endpointNorm) - soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) - soraV1.Use(gin.HandlerFunc(apiKeyAuth)) - soraV1.Use(requireGroupAnthropic) - { - soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) - soraV1.GET("/models", h.Gateway.Models) - } - - // Sora 媒体代理(可选 API Key 验证) - if cfg.Gateway.SoraMediaRequireAPIKey { - r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) - } else { - r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) - } - // Sora 媒体代理(签名 URL,无需 API Key) - r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } // getGroupPlatform extracts the group platform from the API Key stored in context. diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go index 00edd31b..4d65a626 100644 --- a/backend/internal/server/routes/gateway_test.go +++ b/backend/internal/server/routes/gateway_test.go @@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine { &handler.Handlers{ Gateway: &handler.GatewayHandler{}, OpenAIGateway: &handler.OpenAIGatewayHandler{}, - SoraGateway: &handler.SoraGatewayHandler{}, }, servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { c.Next() diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go deleted file mode 100644 index 13fceb81..00000000 --- a/backend/internal/server/routes/sora_client.go +++ /dev/null @@ -1,36 +0,0 @@ -package routes - -import ( - "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 -func RegisterSoraClientRoutes( - v1 *gin.RouterGroup, - h *handler.Handlers, - jwtAuth middleware.JWTAuthMiddleware, - settingService *service.SettingService, -) { - if h.SoraClient == nil { - return - } - - authenticated := v1.Group("/sora") - authenticated.Use(gin.HandlerFunc(jwtAuth)) - authenticated.Use(middleware.BackendModeUserGuard(settingService)) - { - authenticated.POST("/generate", h.SoraClient.Generate) - authenticated.GET("/generations", h.SoraClient.ListGenerations) - authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) - authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) - authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) - authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) - authenticated.GET("/quota", h.SoraClient.GetQuota) - authenticated.GET("/models", h.SoraClient.GetModels) - authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) - } -} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a1449ffd..512195e3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -141,6 +141,21 @@ func (a *Account) IsOAuth() bool { return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken } +// IsPrivacySet 检查账号的 privacy 是否已成功设置。 +// OpenAI: privacy_mode == "training_off" +// Antigravity: privacy_mode == "privacy_set" +// 其他平台: 无 privacy 概念,始终返回 true +func (a *Account) IsPrivacySet() bool { + switch a.Platform { + case PlatformOpenAI: + return a.getExtraString("privacy_mode") == PrivacyModeTrainingOff + case PlatformAntigravity: + return a.getExtraString("privacy_mode") == AntigravityPrivacySet + default: + return true + } +} + func (a *Account) IsGemini() bool { return a.Platform == PlatformGemini } @@ -500,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st } } +func normalizeRequestedModelForLookup(platform, requestedModel string) string { + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return "" + } + if platform != PlatformGemini && platform != PlatformAntigravity { + return trimmed + } + if trimmed == "gemini-3.1-pro-preview-customtools" { + return "gemini-3.1-pro-preview" + } + return trimmed +} + +func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool { + if requestedModel == "" { + return false + } + if _, exists := mapping[requestedModel]; exists { + return true + } + for pattern := range mapping { + if matchWildcard(pattern, requestedModel) { + return true + } + } + return false +} + +func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) { + if requestedModel == "" { + return "", false + } + if mappedModel, exists := mapping[requestedModel]; exists { + return mappedModel, true + } + return matchWildcardMappingResult(mapping, requestedModel) +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -507,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool { if len(mapping) == 0 { return true // 无映射 = 允许所有 } - // 精确匹配 - if _, exists := mapping[requestedModel]; exists { + if mappingSupportsRequestedModel(mapping, requestedModel) { return true } - // 通配符匹配 - for pattern := range mapping { - if matchWildcard(pattern, requestedModel) { - return true - } - } - return false + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized) } // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) @@ -534,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, if len(mapping) == 0 { return requestedModel, false } - // 精确匹配优先 - if mappedModel, exists := mapping[requestedModel]; exists { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched { return mappedModel, true } - // 通配符匹配(最长优先) - return matchWildcardMappingResult(mapping, requestedModel) + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + if normalized != requestedModel { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched { + return mappedModel, true + } + } + return requestedModel, false } func (a *Account) GetBaseURL() string { @@ -1727,22 +1779,47 @@ func (a *Account) GetRPMStrategy() string { } // GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 -// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +// Cache-driven: buffer = concurrency + maxSessions(覆盖幽灵窗口 + 稳态会话需求) +// floor = baseRPM / 5(向后兼容 maxSessions=0 且 concurrency=0 场景) func (a *Account) GetRPMStickyBuffer() int { if a.Extra == nil { return 0 } + + // 手动 override 最高优先级 if v, ok := a.Extra["rpm_sticky_buffer"]; ok { val := parseExtraInt(v) if val > 0 { return val } } + base := a.GetBaseRPM() - buffer := base / 5 - if buffer < 1 && base > 0 { - buffer = 1 + if base <= 0 { + return 0 } + + // Cache-driven buffer = concurrency + maxSessions + conc := a.Concurrency + if conc < 0 { + conc = 0 + } + sess := a.GetMaxSessions() + if sess < 0 { + sess = 0 + } + + buffer := conc + sess + + // floor: 向后兼容 + floor := base / 5 + if floor < 1 { + floor = 1 + } + if buffer < floor { + buffer = floor + } + return buffer } diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go index 9d91f3e0..40298263 100644 --- a/backend/internal/service/account_rpm_test.go +++ b/backend/internal/service/account_rpm_test.go @@ -90,28 +90,47 @@ func TestCheckRPMSchedulability(t *testing.T) { func TestGetRPMStickyBuffer(t *testing.T) { tests := []struct { - name string - extra map[string]any - expected int + name string + concurrency int + extra map[string]any + expected int }{ - {"nil extra", nil, 0}, - {"no keys", map[string]any{}, 0}, - {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, - {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, - {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, - {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, - {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, - {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, - {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, - {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, - {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, - {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, - {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, - {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + // 基础退化 + {"nil extra", 0, nil, 0}, + {"no keys", 0, map[string]any{}, 0}, + {"base_rpm=0", 0, map[string]any{"base_rpm": 0}, 0}, + + // 新公式: concurrency + maxSessions, floor = base/5 + {"conc=3 sess=10 → 13", 3, map[string]any{"base_rpm": 15, "max_sessions": 10}, 13}, + {"conc=2 sess=5 → 7", 2, map[string]any{"base_rpm": 10, "max_sessions": 5}, 7}, + {"conc=3 sess=15 → 18", 3, map[string]any{"base_rpm": 30, "max_sessions": 15}, 18}, + + // floor 生效 (conc+sess < base/5) + {"conc=0 sess=0 base=15 → floor 3", 0, map[string]any{"base_rpm": 15}, 3}, + {"conc=0 sess=0 base=10 → floor 2", 0, map[string]any{"base_rpm": 10}, 2}, + {"conc=0 sess=0 base=1 → floor 1", 0, map[string]any{"base_rpm": 1}, 1}, + {"conc=0 sess=0 base=4 → floor 1", 0, map[string]any{"base_rpm": 4}, 1}, + {"conc=1 sess=0 base=15 → floor 3", 1, map[string]any{"base_rpm": 15}, 3}, + + // 手动 override + {"custom buffer=5", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5, "max_sessions": 10}, 5}, + {"custom buffer=0 fallback", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0, "max_sessions": 10}, 13}, + {"custom buffer negative fallback", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1, "max_sessions": 10}, 13}, + {"custom buffer with float", 3, map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + + // 负值 clamp + {"negative concurrency clamped", -5, map[string]any{"base_rpm": 15, "max_sessions": 10}, 10}, + {"negative maxSessions clamped", 3, map[string]any{"base_rpm": 15, "max_sessions": -5}, 3}, + + // 高并发低会话 + {"conc=10 sess=5 → 15", 10, map[string]any{"base_rpm": 10, "max_sessions": 5}, 15}, + + // json.Number + {"json.Number base_rpm", 3, map[string]any{"base_rpm": json.Number("10"), "max_sessions": json.Number("5")}, 8}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &Account{Extra: tt.extra} + a := &Account{Concurrency: tt.concurrency, Extra: tt.extra} if got := a.GetRPMStickyBuffer(); got != tt.expected { t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected) } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 71d51712..3189a729 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -28,8 +28,7 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) - // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') - // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + // FindByExtraField 根据 extra 字段中的键值对查找账号 FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. @@ -174,6 +173,19 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( return nil, fmt.Errorf("create account: %w", err) } + // require_oauth_only 检查:apikey 类型账号不可加入限制分组 + if account.Type == AccountTypeAPIKey && len(req.GroupIDs) > 0 { + for _, gid := range req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if len(req.GroupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil { @@ -277,6 +289,19 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount return nil, fmt.Errorf("update account: %w", err) } + // require_oauth_only 检查 + if account.Type == AccountTypeAPIKey && req.GroupIDs != nil { + for _, gid := range *req.GroupIDs { + g, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + return nil, err + } + if g.RequireOAuthOnly && (g.Platform == PlatformOpenAI || g.Platform == PlatformAntigravity || g.Platform == PlatformAnthropic || g.Platform == PlatformGemini) { + return nil, fmt.Errorf("分组 [%s] 仅允许 OAuth 账号,apikey 类型账号无法加入", g.Name) + } + } + } + // 绑定分组 if req.GroupIDs != nil { if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index fec98e12..55865945 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -13,18 +13,14 @@ import ( "log" "net/http" "net/http/httptest" - "net/url" "regexp" "strings" - "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" - soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 - soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" - soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" - soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" - soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -71,13 +62,8 @@ type AccountTestService struct { httpUpstream HTTPUpstream cfg *config.Config tlsFPProfileService *TLSFingerprintProfileService - soraTestGuardMu sync.Mutex - soraTestLastRun map[int64]time.Time - soraTestCooldown time.Duration } -const defaultSoraTestCooldown = 10 * time.Second - // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -94,8 +80,6 @@ func NewAccountTestService( httpUpstream: httpUpstream, cfg: cfg, tlsFPProfileService: tlsFPProfileService, - soraTestLastRun: make(map[int64]time.Time), - soraTestCooldown: defaultSoraTestCooldown, } } @@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.routeAntigravityTest(c, account, modelID, prompt) } - if account.Platform == PlatformSora { - return s.testSoraAccountConnection(c, account) - } - return s.testClaudeAccountConnection(c, account, modelID) } @@ -551,6 +531,11 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account account.RateLimitResetAt = resetAt } } + // 401 Unauthorized: 标记账号为永久错误 + if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { + errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } @@ -629,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } -type soraProbeStep struct { - Name string `json:"name"` - Status string `json:"status"` - HTTPStatus int `json:"http_status,omitempty"` - ErrorCode string `json:"error_code,omitempty"` - Message string `json:"message,omitempty"` -} - -type soraProbeSummary struct { - Status string `json:"status"` - Steps []soraProbeStep `json:"steps"` -} - -type soraProbeRecorder struct { - steps []soraProbeStep -} - -func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { - r.steps = append(r.steps, soraProbeStep{ - Name: name, - Status: status, - HTTPStatus: httpStatus, - ErrorCode: strings.TrimSpace(errorCode), - Message: strings.TrimSpace(message), - }) -} - -func (r *soraProbeRecorder) finalize() soraProbeSummary { - meSuccess := false - partial := false - for _, step := range r.steps { - if step.Name == "me" { - meSuccess = strings.EqualFold(step.Status, "success") - continue - } - if strings.EqualFold(step.Status, "failed") { - partial = true - } - } - - status := "success" - if !meSuccess { - status = "failed" - } else if partial { - status = "partial_success" - } - - return soraProbeSummary{ - Status: status, - Steps: append([]soraProbeStep(nil), r.steps...), - } -} - -func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { - if rec == nil { - return - } - summary := rec.finalize() - code := "" - for _, step := range summary.Steps { - if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { - code = step.ErrorCode - break - } - } - s.sendEvent(c, TestEvent{ - Type: "sora_test_result", - Status: summary.Status, - Code: code, - Data: summary, - }) -} - -func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { - if accountID <= 0 { - return 0, true - } - s.soraTestGuardMu.Lock() - defer s.soraTestGuardMu.Unlock() - - if s.soraTestLastRun == nil { - s.soraTestLastRun = make(map[int64]time.Time) - } - cooldown := s.soraTestCooldown - if cooldown <= 0 { - cooldown = defaultSoraTestCooldown - } - - now := time.Now() - if lastRun, ok := s.soraTestLastRun[accountID]; ok { - elapsed := now.Sub(lastRun) - if elapsed < cooldown { - return cooldown - elapsed, false - } - } - s.soraTestLastRun[accountID] = now - return 0, true -} - -func ceilSeconds(d time.Duration) int { - if d <= 0 { - return 1 - } - sec := int(d / time.Second) - if d%time.Second != 0 { - sec++ - } - if sec < 1 { - sec = 1 - } - return sec -} - -// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 -// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 -func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { - ctx := c.Request.Context() - - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") - } - - baseURL := account.GetBaseURL() - if baseURL == "" { - return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") - } - - // 验证 base_url 格式 - normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) - } - upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" - - // 设置 SSE 头 - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { - msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) - return s.sendErrorAndEnd(c, msg) - } - - s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"}) - - // 构建轻量级 prompt-enhance 请求作为连通性测试 - testPayload := map[string]any{ - "model": "prompt-enhance-short-10s", - "messages": []map[string]string{{"role": "user", "content": "test"}}, - "stream": false, - } - payloadBytes, _ := json.Marshal(testPayload) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) - if err != nil { - return s.sendErrorAndEnd(c, "构建测试请求失败") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - // 获取代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) - - if resp.StatusCode == http.StatusOK { - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) - } - - // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 - if resp.StatusCode == http.StatusBadRequest { - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256))) -} - -// testSoraAccountConnection 测试 Sora 账号的连接 -// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性 -// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性 -func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { - // apikey 类型走独立测试流程 - if account.Type == AccountTypeAPIKey { - return s.testSoraAPIKeyAccountConnection(c, account) - } - - ctx := c.Request.Context() - recorder := &soraProbeRecorder{} - - authToken := account.GetCredential("access_token") - if authToken == "" { - recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "No access token available") - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { - msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) - recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, msg) - } - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) - - req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) - if err != nil { - recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Failed to create request") - } - - // 使用 Sora 客户端标准请求头 - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - soraTLSProfile := s.resolveSoraTLSProfile() - - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile) - if err != nil { - recorder.addStep("me", "failed", 0, "network_error", err.Error()) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { - recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") - s.emitSoraProbeSummary(c, recorder) - s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) - return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) - switch { - case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): - recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") - case strings.EqualFold(upstreamCode, "unsupported_country_code"): - recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") - case strings.TrimSpace(upstreamMessage) != "": - recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) - default: - recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) - } - } - recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") - - // 解析 /me 响应,提取用户信息 - var meResp map[string]any - if err := json.Unmarshal(body, &meResp); err != nil { - // 能收到 200 就说明 token 有效 - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"}) - } else { - // 尝试提取用户名或邮箱信息 - info := "Sora connection OK" - if name, ok := meResp["name"].(string); ok && name != "" { - info = fmt.Sprintf("Sora connection OK - User: %s", name) - } else if email, ok := meResp["email"].(string); ok && email != "" { - info = fmt.Sprintf("Sora connection OK - Email: %s", email) - } - s.sendEvent(c, TestEvent{Type: "content", Text: info}) - } - - // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) - subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) - if err == nil { - subReq.Header.Set("Authorization", "Bearer "+authToken) - subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - subReq.Header.Set("Accept", "application/json") - subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") - subReq.Header.Set("Origin", "https://sora.chatgpt.com") - subReq.Header.Set("Referer", "https://sora.chatgpt.com/") - - subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile) - if subErr != nil { - recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) - } else { - subBody, _ := io.ReadAll(subResp.Body) - _ = subResp.Body.Close() - if subResp.StatusCode == http.StatusOK { - recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") - if summary := parseSoraSubscriptionSummary(subBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) - } - } else { - if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { - recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") - s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) - } else { - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) - recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) - } - } - } - } - - // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 - s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder) - - s.emitSoraProbeSummary(c, recorder) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil -} - -func (s *AccountTestService) testSora2Capabilities( - c *gin.Context, - ctx context.Context, - account *Account, - authToken string, - proxyURL string, - tlsProfile *tlsfingerprint.Profile, - recorder *soraProbeRecorder, -) { - inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraInviteMineURL, - proxyURL, - tlsProfile, - ) - if err != nil { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) - return - } - - if inviteStatus == http.StatusUnauthorized { - bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraBootstrapURL, - proxyURL, - tlsProfile, - ) - if bootstrapErr == nil && bootstrapStatus == http.StatusOK { - if recorder != nil { - recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") - } - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) - inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraInviteMineURL, - proxyURL, - tlsProfile, - ) - if err != nil { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) - return - } - } else if recorder != nil { - code := "" - msg := "" - if bootstrapErr != nil { - code = "network_error" - msg = bootstrapErr.Error() - } - recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) - } - } - - if inviteStatus != http.StatusOK { - if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") - } - s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) - return - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) - if recorder != nil { - recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) - return - } - if recorder != nil { - recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") - } - - if summary := parseSoraInviteSummary(inviteBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) - } - - remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraRemainingURL, - proxyURL, - tlsProfile, - ) - if remainingErr != nil { - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) - return - } - if remainingStatus != http.StatusOK { - if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") - } - s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) - return - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) - return - } - if recorder != nil { - recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") - } - if summary := parseSoraRemainingSummary(remainingBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) - } -} - -func (s *AccountTestService) fetchSoraTestEndpoint( - ctx context.Context, - account *Account, - authToken string, - url string, - proxyURL string, - tlsProfile *tlsfingerprint.Profile, -) (int, http.Header, []byte, error) { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return 0, nil, nil, err - } - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile) - if err != nil { - return 0, nil, nil, err - } - defer func() { _ = resp.Body.Close() }() - - body, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return resp.StatusCode, resp.Header, nil, readErr - } - return resp.StatusCode, resp.Header, body, nil -} - -func parseSoraSubscriptionSummary(body []byte) string { - var subResp struct { - Data []struct { - Plan struct { - ID string `json:"id"` - Title string `json:"title"` - } `json:"plan"` - EndTS string `json:"end_ts"` - } `json:"data"` - } - if err := json.Unmarshal(body, &subResp); err != nil { - return "" - } - if len(subResp.Data) == 0 { - return "" - } - - first := subResp.Data[0] - parts := make([]string, 0, 3) - if first.Plan.Title != "" { - parts = append(parts, first.Plan.Title) - } - if first.Plan.ID != "" { - parts = append(parts, first.Plan.ID) - } - if first.EndTS != "" { - parts = append(parts, "end="+first.EndTS) - } - if len(parts) == 0 { - return "" - } - return "Subscription: " + strings.Join(parts, " | ") -} - -func parseSoraInviteSummary(body []byte) string { - var inviteResp struct { - InviteCode string `json:"invite_code"` - RedeemedCount int64 `json:"redeemed_count"` - TotalCount int64 `json:"total_count"` - } - if err := json.Unmarshal(body, &inviteResp); err != nil { - return "" - } - - parts := []string{"Sora2: supported"} - if inviteResp.InviteCode != "" { - parts = append(parts, "invite="+inviteResp.InviteCode) - } - if inviteResp.TotalCount > 0 { - parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) - } - return strings.Join(parts, " | ") -} - -func parseSoraRemainingSummary(body []byte) string { - var remainingResp struct { - RateLimitAndCreditBalance struct { - EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` - RateLimitReached bool `json:"rate_limit_reached"` - AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` - } `json:"rate_limit_and_credit_balance"` - } - if err := json.Unmarshal(body, &remainingResp); err != nil { - return "" - } - info := remainingResp.RateLimitAndCreditBalance - parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} - if info.RateLimitReached { - parts = append(parts, "rate_limited=true") - } - if info.AccessResetsInSeconds > 0 { - parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) - } - return strings.Join(parts, " | ") -} - -func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile { - if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint { - // Sora TLS fingerprint enabled — use built-in default profile - return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"} - } - return nil // disabled -} - -func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) -} - -func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { - return soraerror.FormatCloudflareChallengeMessage(base, headers, body) -} - -func extractCloudflareRayID(headers http.Header, body []byte) string { - return soraerror.ExtractCloudflareRayID(headers, body) -} - -func extractSoraEgressIPHint(headers http.Header) string { - if headers == nil { - return "unknown" - } - candidates := []string{ - "x-openai-public-ip", - "x-envoy-external-address", - "cf-connecting-ip", - "x-forwarded-for", - } - for _, key := range candidates { - if value := strings.TrimSpace(headers.Get(key)); value != "" { - return value - } - } - return "unknown" -} - -func sanitizeProxyURLForLog(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "" - } - u, err := url.Parse(raw) - if err != nil { - return "" - } - if u.User != nil { - u.User = nil - } - return u.String() -} - -func endpointPathForLog(endpoint string) string { - parsed, err := url.Parse(strings.TrimSpace(endpoint)) - if err != nil || parsed.Path == "" { - return endpoint - } - return parsed.Path -} - -func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { - accountID := int64(0) - platform := "" - proxyID := "none" - if account != nil { - accountID = account.ID - platform = account.Platform - if account.ProxyID != nil { - proxyID = fmt.Sprintf("%d", *account.ProxyID) - } - } - cfRay := extractCloudflareRayID(headers, body) - if cfRay == "" { - cfRay = "unknown" - } - log.Printf( - "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", - accountID, - platform, - endpoint, - endpointPathForLog(endpoint), - proxyID, - sanitizeProxyURLForLog(proxyURL), - cfRay, - extractSoraEgressIPHint(headers), - ) -} - -func truncateSoraErrorBody(body []byte, max int) string { - return soraerror.TruncateBody(body, max) -} - // routeAntigravityTest 路由 Antigravity 账号的测试请求。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go index 5ba04c69..f38264a2 100644 --- a/backend/internal/service/account_test_service_gemini_test.go +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) - ctx, recorder := newSoraTestContext() + ctx, recorder := newTestContext() svc := &AccountTestService{} stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index efa6f7da..5125db5b 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -4,16 +4,61 @@ package service import ( "context" + "fmt" "io" "net/http" + "net/http/httptest" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" ) +// --- shared test helpers --- + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, profile != nil) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +// --- test functions --- + +func newTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + type openAIAccountTestRepo struct { mockAccountRepoForGemini updatedExtra map[string]any @@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { gin.SetMode(gin.TestMode) - ctx, recorder := newSoraTestContext() + ctx, recorder := newTestContext() resp := newJSONResponse(http.StatusOK, "") resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} @@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { gin.SetMode(gin.TestMode) - ctx, _ := newSoraTestContext() + ctx, _ := newTestContext() resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) resp.Header.Set("x-codex-primary-used-percent", "100") diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go deleted file mode 100644 index 52f832a9..00000000 --- a/backend/internal/service/account_test_service_sora_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package service - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -type queuedHTTPUpstream struct { - responses []*http.Response - requests []*http.Request - tlsFlags []bool -} - -func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - return nil, fmt.Errorf("unexpected Do call") -} - -func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) { - u.requests = append(u.requests, req) - u.tlsFlags = append(u.tlsFlags, profile != nil) - if len(u.responses) == 0 { - return nil, fmt.Errorf("no mocked response") - } - resp := u.responses[0] - u.responses = u.responses[1:] - return resp, nil -} - -func newJSONResponse(status int, body string) *http.Response { - return &http.Response{ - StatusCode: status, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader(body)), - } -} - -func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { - resp := newJSONResponse(status, body) - resp.Header.Set(key, value) - return resp -} - -func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { - gin.SetMode(gin.TestMode) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) - return c, rec -} - -func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), - newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), - newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), - newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), - }, - } - svc := &AccountTestService{ - httpUpstream: upstream, - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - TLSFingerprint: config.TLSFingerprintConfig{ - Enabled: true, - }, - }, - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - DisableTLSFingerprint: false, - }, - }, - }, - } - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - require.Len(t, upstream.requests, 4) - require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) - require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) - require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) - require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) - require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) - require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) - require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) - - body := rec.Body.String() - require.Contains(t, body, `"type":"test_start"`) - require.Contains(t, body, "Sora connection OK - Email: demo@example.com") - require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") - require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") - require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"success"`) - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), - newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), - newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), - newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - require.Len(t, upstream.requests, 4) - body := rec.Body.String() - require.Contains(t, body, "Sora connection OK - User: demo-user") - require.Contains(t, body, "Subscription check returned 403") - require.Contains(t, body, "Sora2 invite check returned 401") - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"partial_success"`) - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "Cloudflare challenge") - require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") - body := rec.Body.String() - require.Contains(t, body, `"type":"error"`) - require.Contains(t, body, "Cloudflare challenge") - require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") -} - -func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "Cloudflare challenge") - require.Contains(t, err.Error(), "HTTP 429") - body := rec.Body.String() - require.Contains(t, body, "Cloudflare challenge") -} - -func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "token_invalidated") - body := rec.Body.String() - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"failed"`) - require.Contains(t, body, "token_invalidated") - require.NotContains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), - }, - } - svc := &AccountTestService{ - httpUpstream: upstream, - soraTestCooldown: time.Hour, - } - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c1, _ := newSoraTestContext() - err := svc.testSoraAccountConnection(c1, account) - require.NoError(t, err) - - c2, rec2 := newSoraTestContext() - err = svc.testSoraAccountConnection(c2, account) - require.Error(t, err) - require.Contains(t, err.Error(), "测试过于频繁") - body := rec2.Body.String() - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"code":"test_rate_limited"`) - require.Contains(t, body, `"status":"failed"`) - require.NotContains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), - newJSONResponse(http.StatusForbidden, `Just a moment...`), - newJSONResponse(http.StatusForbidden, `Just a moment...`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - body := rec.Body.String() - require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") - require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") - require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestSanitizeProxyURLForLog(t *testing.T) { - require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) - require.Equal(t, "", sanitizeProxyURLForLog("")) - require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) -} - -func TestExtractSoraEgressIPHint(t *testing.T) { - h := make(http.Header) - h.Set("x-openai-public-ip", "203.0.113.10") - require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) - - h2 := make(http.Header) - h2.Set("x-envoy-external-address", "198.51.100.9") - require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) - - require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) - require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) -} diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 0d7ffffa..d903b940 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected bool @@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { requestedModel: "claude-opus-4-5-thinking", expected: true, }, + { + name: "gemini customtools alias matches normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: true, + }, { name: "wildcard match not supported", credentials: map[string]any{ @@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.IsModelSupported(tt.requestedModel) @@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected string @@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", }, + { + name: "no mapping preserves gemini customtools model", + platform: PlatformGemini, + credentials: nil, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, // 精确匹配 { @@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { }, // 无匹配返回原始模型 + { + name: "gemini customtools alias resolves through normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview", + }, + { + name: "gemini customtools exact mapping wins over normalized fallback", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, { name: "no match returns original", credentials: map[string]any{ @@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.GetMappedModel(tt.requestedModel) @@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expectedModel string @@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { expectedModel: "gpt-5.4", expectedMatch: true, }, + { + name: "gemini customtools alias reports normalized match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview", + expectedMatch: true, + }, + { + name: "gemini customtools exact mapping reports exact match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview-customtools", + expectedMatch: true, + }, { name: "missing mapping reports unmatched", credentials: map[string]any{ @@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 88c064f3..8032f871 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -14,7 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + "github.com/Wei-Shaw/sub2api/internal/util/httputil" ) // AdminService interface defines admin management operations @@ -103,14 +104,13 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 - SoraStorageQuotaBytes int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 } type UpdateUserInput struct { @@ -124,8 +124,7 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 - SoraStorageQuotaBytes *int64 + GroupRates map[int64]*float64 } type CreateGroupInput struct { @@ -139,16 +138,11 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - // Sora 按次计费配置 - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -157,11 +151,11 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string - // Sora 存储配额 - SoraStorageQuotaBytes int64 // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -178,16 +172,11 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - // Sora 按次计费配置 - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -196,11 +185,11 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string - // Sora 存储配额 - SoraStorageQuotaBytes *int64 // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -421,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{ http.StatusOK: {}, }, }, - { - Target: "sora", - URL: "https://sora.chatgpt.com/backend/me", - Method: http.MethodGet, - AllowedStatuses: map[int]struct{}{ - http.StatusUnauthorized: {}, - }, - }, } const ( @@ -443,7 +424,6 @@ type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -468,7 +448,6 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, - soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -487,7 +466,6 @@ func NewAdminService( userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, - soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -569,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, - SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -649,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } - if input.SoraStorageQuotaBytes != nil { - user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes - } - if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -855,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) - soraImagePrice360 := normalizePrice(input.SoraImagePrice360) - soraImagePrice540 := normalizePrice(input.SoraImagePrice540) - soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) - soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -929,24 +898,42 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, - SoraImagePrice360: soraImagePrice360, - SoraImagePrice540: soraImagePrice540, - SoraVideoPricePerRequest: soraVideoPrice, - SoraVideoPricePerRequestHD: soraVideoPriceHD, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, - SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, AllowMessagesDispatch: input.AllowMessagesDispatch, + RequireOAuthOnly: input.RequireOAuthOnly, + RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 如果有需要复制的账号,绑定到新分组 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { @@ -1087,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } - if input.SoraImagePrice360 != nil { - group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) - } - if input.SoraImagePrice540 != nil { - group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) - } - if input.SoraVideoPricePerRequest != nil { - group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) - } - if input.SoraVideoPricePerRequestHD != nil { - group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) - } - if input.SoraStorageQuotaBytes != nil { - group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes - } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -1154,6 +1126,12 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.AllowMessagesDispatch != nil { group.AllowMessagesDispatch = *input.AllowMessagesDispatch } + if input.RequireOAuthOnly != nil { + group.RequireOAuthOnly = *input.RequireOAuthOnly + } + if input.RequirePrivacySet != nil { + group.RequirePrivacySet = *input.RequirePrivacySet + } if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } @@ -1201,6 +1179,27 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) } + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + // 再绑定源分组的账号 if len(accountIDsToCopy) > 0 { if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { @@ -1511,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } - // Sora apikey 账号的 base_url 必填校验 - if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { - baseURL, _ := input.Credentials["base_url"].(string) - baseURL = strings.TrimSpace(baseURL) - if baseURL == "" { - return nil, errors.New("sora apikey 账号必须设置 base_url") - } - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") - } - } - account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1568,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } - // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 - if account.Platform == PlatformSora && s.soraAccountRepo != nil { - soraUpdates := map[string]any{ - "access_token": account.GetCredential("access_token"), - "refresh_token": account.GetCredential("refresh_token"), - } - if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { - // 只记录警告日志,不阻塞账号创建 - logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) - } - } - // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { @@ -1587,6 +1562,31 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // OAuth 账号:创建后异步设置隐私。 + // 使用 Ensure(幂等)而非 Force:新建账号 Extra 为空时效果相同,但更安全。 + if account.Type == AccountTypeOAuth { + switch account.Platform { + case PlatformOpenAI: + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("create_account_openai_privacy_panic", "account_id", account.ID, "recover", r) + } + }() + s.EnsureOpenAIPrivacy(context.Background(), account) + }() + case PlatformAntigravity: + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r) + } + }() + s.EnsureAntigravityPrivacy(context.Background(), account) + }() + } + } + return account, nil } @@ -1683,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } - // Sora apikey 账号的 base_url 必填校验 - if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { - baseURL, _ := account.Credentials["base_url"].(string) - baseURL = strings.TrimSpace(baseURL) - if baseURL == "" { - return nil, errors.New("sora apikey 账号必须设置 base_url") - } - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") - } - } - // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { @@ -2297,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox body = body[:proxyQualityMaxBodyBytes] } - if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + // Cloudflare challenge 检测 + if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { item.Status = "challenge" - item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) - item.Message = "Sora 命中 Cloudflare challenge" + item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body) + item.Message = "命中 Cloudflare challenge" return item } @@ -2716,16 +2705,14 @@ func (s *adminServiceImpl) ForceOpenAIPrivacy(ctx context.Context, account *Acco } // EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号隐私状态。 -// 如果 Extra["privacy_mode"] 已存在(无论成功或失败),直接跳过。 -// 仅对从未设置过隐私的账号执行 setUserSettings + fetchUserInfo 流程。 -// 用户可通过前端 ForceAntigravityPrivacy(SetPrivacy 按钮)强制重新设置。 +// 仅当 privacy_mode 已成功设置("privacy_set")时跳过; +// 未设置或之前失败("privacy_set_failed")均会重试。 func (s *adminServiceImpl) EnsureAntigravityPrivacy(ctx context.Context, account *Account) string { if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { return "" } - // 已设置过则跳过(无论成功或失败),用户可通过 Force 手动重试 if account.Extra != nil { - if existing, ok := account.Extra["privacy_mode"].(string); ok && existing != "" { + if existing, ok := account.Extra["privacy_mode"].(string); ok && existing == AntigravityPrivacySet { return existing } } diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go index 5a43cd9c..d3b3f61b 100644 --- a/backend/internal/service/admin_service_proxy_quality_test.go +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { require.Contains(t, result.Summary, "挑战 1 项") } -func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { +func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") w.Header().Set("cf-ray", "test-ray-123") @@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { defer server.Close() target := proxyQualityTarget{ - Target: "sora", + Target: "openai", URL: server.URL, Method: http.MethodGet, AllowedStatuses: map[int]struct{}{ diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 1dbe9870..a29000e7 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { requestedModel: "gemini-2.5-flash", expected: "gemini-2.5-flash", }, + { + name: "customtools alias falls back to normalized preview mapping", + modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"}, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-high", + }, } for _, tt := range tests { diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index a300d898..3a4600db 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -91,6 +91,7 @@ type AntigravityTokenInfo struct { ProjectID string `json:"project_id,omitempty"` ProjectIDMissing bool `json:"-"` PlanType string `json:"-"` + PrivacyMode string `json:"-"` } // ExchangeCode 用 authorization code 交换 token @@ -159,6 +160,9 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } + // 令牌刚获取,立即设置隐私(不依赖后续账号创建流程) + result.PrivacyMode = setAntigravityPrivacy(ctx, result.AccessToken, result.ProjectID, proxyURL) + return result, nil } @@ -248,6 +252,9 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } } + // 令牌刚获取,立即设置隐私 + tokenInfo.PrivacyMode = setAntigravityPrivacy(ctx, tokenInfo.AccessToken, tokenInfo.ProjectID, proxyURL) + return tokenInfo, nil } diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index ecaffcbc..e3b60a27 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -5,13 +5,12 @@ package service import ( "bytes" "context" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/stretchr/testify/require" "io" "net/http" "strings" "testing" - - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/stretchr/testify/require" ) // stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock @@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI m.responseBodies[respIdx] = bodyBytes } - // 用缓存的 body 字节重建新的 reader - var body io.ReadCloser + // 用缓存的 body 重建 reader(支持重试场景多次读取) + cloned := *resp if m.responseBodies[respIdx] != nil { - body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx])) + cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx])) } - - return &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: body, - }, respErr + return &cloned, respErr } func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index e8ad5c9c..ad6ba0e9 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f727ab10..64a70e8c 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, - SoraImagePrice360: apiKey.Group.SoraImagePrice360, - SoraImagePrice540: apiKey.Group.SoraImagePrice540, - SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, @@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, - SoraImagePrice360: snapshot.Group.SoraImagePrice360, - SoraImagePrice540: snapshot.Group.SoraImagePrice540, - SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 004511f5..763abadb 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -56,6 +56,7 @@ type ModelPricing struct { LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 + ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD) } const ( @@ -94,16 +95,19 @@ type UsageTokens struct { CacheReadTokens int CacheCreation5mTokens int CacheCreation1hTokens int + ImageOutputTokens int } // CostBreakdown 费用明细 type CostBreakdown struct { InputCost float64 OutputCost float64 + ImageOutputCost float64 CacheCreationCost float64 CacheReadCost float64 TotalCost float64 ActualCost float64 // 应用倍率后的实际费用 + BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充 } // BillingService 计费服务 @@ -357,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken, }), nil } } @@ -371,81 +376,252 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { return nil, fmt.Errorf("pricing not found for model: %s", model) } -// CalculateCost 计算使用费用 -func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { - return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") -} - -func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { +// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值 +// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价 +func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) { pricing, err := s.GetModelPricing(model) if err != nil { return nil, err } + if channelPricing == nil { + return pricing, nil + } + if channelPricing.InputPrice != nil { + pricing.InputPricePerToken = *channelPricing.InputPrice + pricing.InputPricePerTokenPriority = *channelPricing.InputPrice + } + if channelPricing.OutputPrice != nil { + pricing.OutputPricePerToken = *channelPricing.OutputPrice + pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice + } + if channelPricing.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice + pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice + pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice + } + if channelPricing.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice + } + if channelPricing.ImageOutputPrice != nil { + pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice + } + return pricing, nil +} - breakdown := &CostBreakdown{} - inputPricePerToken := pricing.InputPricePerToken - outputPricePerToken := pricing.OutputPricePerToken - cacheReadPricePerToken := pricing.CacheReadPricePerToken +// --- 统一计费入口 --- + +// CostInput 统一计费输入 +type CostInput struct { + Ctx context.Context + Model string + GroupID *int64 // 用于渠道定价查找 + Tokens UsageTokens + RequestCount int // 按次计费时使用 + SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) + RateMultiplier float64 + ServiceTier string // "priority","flex","" 等 + Resolver *ModelPricingResolver // 定价解析器 + Resolved *ResolvedPricing // 可选:预解析的定价结果(避免重复 Resolve 调用) +} + +// CalculateCostUnified 统一计费入口,支持三种计费模式。 +// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。 +func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) { + if input.Resolver == nil { + // 无 Resolver,回退到旧路径 + return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil) + } + + // 优先使用预解析结果,避免重复 Resolve 调用 + resolved := input.Resolved + if resolved == nil { + resolved = input.Resolver.Resolve(input.Ctx, PricingInput{ + Model: input.Model, + GroupID: input.GroupID, + }) + } + + if input.RateMultiplier <= 0 { + input.RateMultiplier = 1.0 + } + + var breakdown *CostBreakdown + var err error + switch resolved.Mode { + case BillingModePerRequest, BillingModeImage: + breakdown, err = s.calculatePerRequestCost(resolved, input) + default: // BillingModeToken + breakdown, err = s.calculateTokenCost(resolved, input) + } + if err == nil && breakdown != nil { + breakdown.BillingMode = string(resolved.Mode) + if breakdown.BillingMode == "" { + breakdown.BillingMode = string(BillingModeToken) + } + } + return breakdown, err +} + +// calculateTokenCost 按 token 区间计费 +func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + + pricing := input.Resolver.GetIntervalPricing(resolved, totalContext) + if pricing == nil { + return nil, fmt.Errorf("no pricing available for model: %s", input.Model) + } + + pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing) + + // 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层) + applyLongCtx := len(resolved.Intervals) == 0 + + return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil +} + +// computeTokenBreakdown 是 token 计费的核心逻辑,由 calculateTokenCost 和 calculateCostInternal 共用。 +// applyLongCtx 控制是否检查长上下文定价(区间定价已自含上下文分层,不需要额外应用)。 +func (s *BillingService) computeTokenBreakdown( + pricing *ModelPricing, tokens UsageTokens, + rateMultiplier float64, serviceTier string, + applyLongCtx bool, +) *CostBreakdown { + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + + inputPrice := pricing.InputPricePerToken + outputPrice := pricing.OutputPricePerToken + cacheReadPrice := pricing.CacheReadPricePerToken tierMultiplier := 1.0 + if usePriorityServiceTierPricing(serviceTier, pricing) { if pricing.InputPricePerTokenPriority > 0 { - inputPricePerToken = pricing.InputPricePerTokenPriority + inputPrice = pricing.InputPricePerTokenPriority } if pricing.OutputPricePerTokenPriority > 0 { - outputPricePerToken = pricing.OutputPricePerTokenPriority + outputPrice = pricing.OutputPricePerTokenPriority } if pricing.CacheReadPricePerTokenPriority > 0 { - cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + cacheReadPrice = pricing.CacheReadPricePerTokenPriority } } else { tierMultiplier = serviceTierCostMultiplier(serviceTier) } - if s.shouldApplySessionLongContextPricing(tokens, pricing) { - inputPricePerToken *= pricing.LongContextInputMultiplier - outputPricePerToken *= pricing.LongContextOutputMultiplier + + if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) { + inputPrice *= pricing.LongContextInputMultiplier + outputPrice *= pricing.LongContextOutputMultiplier } - // 计算输入token费用(使用per-token价格) - breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken + bd := &CostBreakdown{} + bd.InputCost = float64(tokens.InputTokens) * inputPrice - // 计算输出token费用 - breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken + // 分离图片输出 token 与文本输出 token + textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens + if textOutputTokens < 0 { + textOutputTokens = 0 + } + bd.OutputCost = float64(textOutputTokens) * outputPrice - // 计算缓存费用 - if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) - if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { - // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 - breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice - } else { - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + // 图片输出 token 费用(独立费率) + if tokens.ImageOutputTokens > 0 { + imgPrice := pricing.ImageOutputPricePerToken + if imgPrice == 0 { + imgPrice = outputPrice // 回退到常规输出价格 } - } else { - // 标准缓存创建价格(per-token) - breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken + bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice } - breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken + // 缓存创建费用 + bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens) + + bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice if tierMultiplier != 1.0 { - breakdown.InputCost *= tierMultiplier - breakdown.OutputCost *= tierMultiplier - breakdown.CacheCreationCost *= tierMultiplier - breakdown.CacheReadCost *= tierMultiplier + bd.InputCost *= tierMultiplier + bd.OutputCost *= tierMultiplier + bd.ImageOutputCost *= tierMultiplier + bd.CacheCreationCost *= tierMultiplier + bd.CacheReadCost *= tierMultiplier } - // 计算总费用 - breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + - breakdown.CacheCreationCost + breakdown.CacheReadCost + bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost + + bd.CacheCreationCost + bd.CacheReadCost + bd.ActualCost = bd.TotalCost * rateMultiplier - // 应用倍率计算实际费用 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + return bd +} + +// computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。 +func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 { + if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { + // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 + return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } + return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice } - breakdown.ActualCost = breakdown.TotalCost * rateMultiplier + return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken +} - return breakdown, nil +// calculatePerRequestCost 按次/图片计费 +func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + count := input.RequestCount + if count <= 0 { + count = 1 + } + + var unitPrice float64 + + if input.SizeTier != "" { + unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier) + } + + if unitPrice == 0 { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext) + } + + // 回退到默认按次价格 + if unitPrice == 0 { + unitPrice = resolved.DefaultPerRequestPrice + } + + totalCost := unitPrice * float64(count) + actualCost := totalCost * input.RateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + }, nil +} + +// CalculateCost 计算使用费用 +func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { + return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil) +} + +func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { + return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil) +} + +func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) { + var pricing *ModelPricing + var err error + if channelPricing != nil { + pricing, err = s.GetModelPricingWithChannel(model, channelPricing) + } else { + pricing, err = s.GetModelPricing(model) + } + if err != nil { + return nil, err + } + + // 旧路径始终检查长上下文定价(无区间定价概念) + return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil } func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { @@ -541,6 +717,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage CacheReadTokens: inRangeCacheTokens, CacheCreation5mTokens: tokens.CacheCreation5mTokens, CacheCreation1hTokens: tokens.CacheCreation1hTokens, + ImageOutputTokens: tokens.ImageOutputTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { @@ -561,6 +738,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage return &CostBreakdown{ InputCost: inRangeCost.InputCost + outRangeCost.InputCost, OutputCost: inRangeCost.OutputCost, + ImageOutputCost: inRangeCost.ImageOutputCost, CacheCreationCost: inRangeCost.CacheCreationCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, @@ -630,14 +808,6 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } -// SoraPriceConfig Sora 按次计费配置 -type SoraPriceConfig struct { - ImagePrice360 *float64 - ImagePrice540 *float64 - VideoPricePerRequest *float64 - VideoPricePerRequestHD *float64 -} - // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -662,67 +832,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag actualCost := totalCost * rateMultiplier return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, - } -} - -// CalculateSoraImageCost 计算 Sora 图片按次费用 -func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { - if imageCount <= 0 { - return &CostBreakdown{} - } - - unitPrice := 0.0 - if groupConfig != nil { - switch imageSize { - case "540": - if groupConfig.ImagePrice540 != nil { - unitPrice = *groupConfig.ImagePrice540 - } - default: - if groupConfig.ImagePrice360 != nil { - unitPrice = *groupConfig.ImagePrice360 - } - } - } - - totalCost := unitPrice * float64(imageCount) - if rateMultiplier <= 0 { - rateMultiplier = 1.0 - } - actualCost := totalCost * rateMultiplier - - return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, - } -} - -// CalculateSoraVideoCost 计算 Sora 视频按次费用 -func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { - unitPrice := 0.0 - if groupConfig != nil { - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "sora2pro-hd") { - if groupConfig.VideoPricePerRequestHD != nil { - unitPrice = *groupConfig.VideoPricePerRequestHD - } - } - if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { - unitPrice = *groupConfig.VideoPricePerRequest - } - } - - totalCost := unitPrice - if rateMultiplier <= 0 { - rateMultiplier = 1.0 - } - actualCost := totalCost * rateMultiplier - - return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, + TotalCost: totalCost, + ActualCost: actualCost, + BillingMode: string(BillingModeImage), } } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 10943422..dd58502c 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) { require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) } -func TestCalculateSoraVideoCost(t *testing.T) { - svc := newTestBillingService() - - price := 0.5 - cfg := &SoraPriceConfig{VideoPricePerRequest: &price} - cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) - - require.InDelta(t, 0.5, cost.TotalCost, 1e-10) -} - -func TestCalculateSoraVideoCost_HDModel(t *testing.T) { - svc := newTestBillingService() - - hdPrice := 1.0 - normalPrice := 0.5 - cfg := &SoraPriceConfig{ - VideoPricePerRequest: &normalPrice, - VideoPricePerRequestHD: &hdPrice, - } - cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) - require.InDelta(t, 1.0, cost.TotalCost, 1e-10) -} func TestIsModelSupported(t *testing.T) { svc := newTestBillingService() @@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) { require.Contains(t, err.Error(), "not initialized") } -func TestCalculateSoraImageCost(t *testing.T) { - svc := newTestBillingService() - - price360 := 0.05 - price540 := 0.08 - cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540} - - cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0) - require.InDelta(t, 0.10, cost.TotalCost, 1e-10) - - cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0) - require.InDelta(t, 0.08, cost540.TotalCost, 1e-10) - require.InDelta(t, 0.16, cost540.ActualCost, 1e-10) -} - -func TestCalculateSoraImageCost_ZeroCount(t *testing.T) { - svc := newTestBillingService() - cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0) - require.Equal(t, 0.0, cost.TotalCost) -} - -func TestCalculateSoraVideoCost_NilConfig(t *testing.T) { - svc := newTestBillingService() - cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0) - require.Equal(t, 0.0, cost.TotalCost) -} - func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) { // 使用空的 fallback prices 让 GetModelPricing 失败 svc := &BillingService{ diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go new file mode 100644 index 00000000..1697ed6f --- /dev/null +++ b/backend/internal/service/channel.go @@ -0,0 +1,277 @@ +package service + +import ( + "fmt" + "sort" + "strings" + "time" +) + +// BillingMode 计费模式 +type BillingMode string + +const ( + BillingModeToken BillingMode = "token" // 按 token 区间计费 + BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) + BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) +) + +// IsValid 检查 BillingMode 是否为合法值 +func (m BillingMode) IsValid() bool { + switch m { + case BillingModeToken, BillingModePerRequest, BillingModeImage, "": + return true + } + return false +} + +const ( + BillingModelSourceRequested = "requested" + BillingModelSourceUpstream = "upstream" + BillingModelSourceChannelMapped = "channel_mapped" +) + +// Channel 渠道实体 +type Channel struct { + ID int64 + Name string + Description string + Status string + BillingModelSource string // "requested", "upstream", or "channel_mapped" + RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + CreatedAt time.Time + UpdatedAt time.Time + + // 关联的分组 ID 列表 + GroupIDs []int64 + // 模型定价列表(每条含 Platform 字段) + ModelPricing []ChannelModelPricing + // 渠道级模型映射(按平台分组:platform → {src→dst}) + ModelMapping map[string]map[string]string +} + +// ChannelModelPricing 渠道模型定价条目 +type ChannelModelPricing struct { + ID int64 + ChannelID int64 + Platform string // 所属平台(anthropic/openai/gemini/...) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) + PerRequestPrice *float64 // 默认按次计费价格(USD) + Intervals []PricingInterval // 区间定价列表 + CreatedAt time.Time + UpdatedAt time.Time +} + +// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层) +type PricingInterval struct { + ID int64 + PricingID int64 + MinTokens int // 区间下界(含) + MaxTokens *int // 区间上界(不含),nil = 无上限 + TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等) + InputPrice *float64 // token 模式:每 token 输入价 + OutputPrice *float64 // token 模式:每 token 输出价 + CacheWritePrice *float64 // token 模式:缓存写入价 + CacheReadPrice *float64 // token 模式:缓存读取价 + PerRequestPrice *float64 // 按次/图片模式:每次请求价格 + SortOrder int + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsActive 判断渠道是否启用 +func (c *Channel) IsActive() bool { + return c.Status == StatusActive +} + +// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 +// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。 +func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { + modelLower := strings.ToLower(model) + + for i := range c.ModelPricing { + for _, m := range c.ModelPricing[i].Models { + if strings.ToLower(m) == modelLower { + cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + + return nil +} + +// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。 +// 区间为左开右闭 (min, max]:min 不含,max 包含。 +// 第一个区间 min=0 时,0 token 不匹配任何区间(回退到默认价格)。 +func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval { + for i := range intervals { + iv := &intervals[i] + if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) { + return iv + } + } + return nil +} + +// GetIntervalForContext 根据总 context token 数查找匹配的区间。 +func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval { + return FindMatchingInterval(p.Intervals, totalTokens) +} + +// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式) +func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval { + labelLower := strings.ToLower(label) + for i := range p.Intervals { + if strings.ToLower(p.Intervals[i].TierLabel) == labelLower { + return &p.Intervals[i] + } + } + return nil +} + +// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全) +func (p ChannelModelPricing) Clone() ChannelModelPricing { + cp := p + if p.Models != nil { + cp.Models = make([]string, len(p.Models)) + copy(cp.Models, p.Models) + } + if p.Intervals != nil { + cp.Intervals = make([]PricingInterval, len(p.Intervals)) + copy(cp.Intervals, p.Intervals) + } + return cp +} + +// Clone 返回 Channel 的深拷贝 +func (c *Channel) Clone() *Channel { + if c == nil { + return nil + } + cp := *c + if c.GroupIDs != nil { + cp.GroupIDs = make([]int64, len(c.GroupIDs)) + copy(cp.GroupIDs, c.GroupIDs) + } + if c.ModelPricing != nil { + cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing)) + for i := range c.ModelPricing { + cp.ModelPricing[i] = c.ModelPricing[i].Clone() + } + } + if c.ModelMapping != nil { + cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping)) + for platform, mapping := range c.ModelMapping { + inner := make(map[string]string, len(mapping)) + for k, v := range mapping { + inner[k] = v + } + cp.ModelMapping[platform] = inner + } + } + return &cp +} + +// ValidateIntervals 校验区间列表的合法性。 +// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; +// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); +// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。 +func ValidateIntervals(intervals []PricingInterval) error { + if len(intervals) == 0 { + return nil + } + sorted := make([]PricingInterval, len(intervals)) + copy(sorted, intervals) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].MinTokens < sorted[j].MinTokens + }) + + for i := range sorted { + if err := validateSingleInterval(&sorted[i], i); err != nil { + return err + } + } + return validateIntervalOverlap(sorted) +} + +// validateSingleInterval 校验单个区间的字段合法性 +func validateSingleInterval(iv *PricingInterval, idx int) error { + if iv.MinTokens < 0 { + return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens) + } + if iv.MaxTokens != nil { + if *iv.MaxTokens <= 0 { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens) + } + if *iv.MaxTokens <= iv.MinTokens { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)", + idx+1, *iv.MaxTokens, iv.MinTokens) + } + } + return validateIntervalPrices(iv, idx) +} + +// validateIntervalPrices 校验区间内所有价格字段 >= 0 +func validateIntervalPrices(iv *PricingInterval, idx int) error { + prices := []struct { + name string + val *float64 + }{ + {"input_price", iv.InputPrice}, + {"output_price", iv.OutputPrice}, + {"cache_write_price", iv.CacheWritePrice}, + {"cache_read_price", iv.CacheReadPrice}, + {"per_request_price", iv.PerRequestPrice}, + } + for _, p := range prices { + if p.val != nil && *p.val < 0 { + return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name) + } + } + return nil +} + +// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后 +func validateIntervalOverlap(sorted []PricingInterval) error { + for i, iv := range sorted { + // 无界区间必须是最后一个 + if iv.MaxTokens == nil && i < len(sorted)-1 { + return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one", + i+1) + } + if i == 0 { + continue + } + prev := sorted[i-1] + // 检查重叠:前一个区间的上界 > 当前区间的下界则重叠 + // (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max] + if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens { + return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d", + i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens) + } + } + return nil +} + +func formatMaxTokensLabel(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + +// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) +type ChannelUsageFields struct { + ChannelID int64 // 渠道 ID(0 = 无渠道) + OriginalModel string // 用户原始请求模型(渠道映射前) + ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel) + BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" + ModelMappingChain string // 映射链描述,如 "a→b→c" +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go new file mode 100644 index 00000000..7b96084d --- /dev/null +++ b/backend/internal/service/channel_service.go @@ -0,0 +1,842 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" +) + +var ( + ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found") + ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists") + ErrGroupAlreadyInChannel = infraerrors.Conflict( + "GROUP_ALREADY_IN_CHANNEL", + "one or more groups already belong to another channel", + ) +) + +// ChannelRepository 渠道数据访问接口 +type ChannelRepository interface { + Create(ctx context.Context, channel *Channel) error + GetByID(ctx context.Context, id int64) (*Channel, error) + Update(ctx context.Context, channel *Channel) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + ListAll(ctx context.Context) ([]Channel, error) + ExistsByName(ctx context.Context, name string) (bool, error) + ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) + + // 分组关联 + GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) + SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error + GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) + GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + + // 分组平台查询 + GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) + + // 模型定价 + ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + DeleteModelPricing(ctx context.Context, id int64) error + ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突) +type channelModelKey struct { + groupID int64 + platform string // 平台标识 + model string // lowercase +} + +// channelGroupPlatformKey 通配符定价缓存键 +type channelGroupPlatformKey struct { + groupID int64 + platform string +} + +// wildcardPricingEntry 通配符定价条目 +type wildcardPricingEntry struct { + prefix string + pricing *ChannelModelPricing +} + +// wildcardMappingEntry 通配符映射条目 +type wildcardMappingEntry struct { + prefix string + target string +} + +// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) +type channelCache struct { + // 热路径查找 + pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 + wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序) + mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 + wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序) + channelByGroupID map[int64]*Channel // groupID → 渠道 + groupPlatform map[int64]string // groupID → platform + + // 冷路径(CRUD 操作) + byID map[int64]*Channel + loadedAt time.Time +} + +// ChannelMappingResult 渠道映射查找结果 +type ChannelMappingResult struct { + MappedModel string // 映射后的模型名(无映射时等于原始模型名) + ChannelID int64 // 渠道 ID(0 = 无渠道关联) + Mapped bool // 是否发生了映射 + BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped") +} + +// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。 +// reqModel: 客户端请求的原始模型名。 +// upstreamModel: 上游实际使用的模型名(ForwardResult.UpstreamModel)。 +// 返回空字符串表示无映射。 +func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string { + if !r.Mapped { + if upstreamModel != "" && upstreamModel != reqModel { + return reqModel + "→" + upstreamModel + } + return "" + } + if upstreamModel != "" && upstreamModel != r.MappedModel { + return reqModel + "→" + r.MappedModel + "→" + upstreamModel + } + return reqModel + "→" + r.MappedModel +} + +// ToUsageFields 将渠道映射结果转为使用记录字段 +func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields { + channelMappedModel := reqModel + if r.Mapped { + channelMappedModel = r.MappedModel + } + return ChannelUsageFields{ + ChannelID: r.ChannelID, + OriginalModel: reqModel, + ChannelMappedModel: channelMappedModel, + BillingModelSource: r.BillingModelSource, + ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel), + } +} + +const ( + channelCacheTTL = 10 * time.Minute + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelCacheDBTimeout = 10 * time.Second +) + +// ChannelService 渠道管理服务 +type ChannelService struct { + repo ChannelRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + + cache atomic.Value // *channelCache + cacheSF singleflight.Group +} + +// NewChannelService 创建渠道服务实例 +func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { + s := &ChannelService{ + repo: repo, + authCacheInvalidator: authCacheInvalidator, + } + return s +} + +// loadCache 加载或返回缓存的渠道数据 +func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) { + if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + + result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) { + // 双重检查 + if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + return s.buildCache(ctx) + }) + if err != nil { + return nil, err + } + cache, ok := result.(*channelCache) + if !ok { + return nil, fmt.Errorf("unexpected cache type") + } + return cache, nil +} + +// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化) +func newEmptyChannelCache() *channelCache { + return &channelCache{ + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), + mappingByGroupModel: make(map[channelModelKey]string), + wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: make(map[int64]string), + byID: make(map[int64]*Channel), + } +} + +// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 +// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。 +// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。 +func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for j := range ch.ModelPricing { + pricing := &ch.ModelPricing[j] + if !isPlatformPricingMatch(platform, pricing.Platform) { + continue // 跳过非本平台的定价 + } + // 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突 + pricingPlatform := pricing.Platform + gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform} + for _, model := range pricing.Models { + if strings.HasSuffix(model, "*") { + prefix := strings.ToLower(strings.TrimSuffix(model, "*")) + cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{ + prefix: prefix, + pricing: pricing, + }) + } else { + key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)} + cache.pricingByGroupModel[key] = pricing + } + } + } +} + +// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 +// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。 +func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for _, mappingPlatform := range matchingPlatforms(platform) { + platformMapping, ok := ch.ModelMapping[mappingPlatform] + if !ok { + continue + } + // 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突 + gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform} + for src, dst := range platformMapping { + if strings.HasSuffix(src, "*") { + prefix := strings.ToLower(strings.TrimSuffix(src, "*")) + cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{ + prefix: prefix, + target: dst, + }) + } else { + key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)} + cache.mappingByGroupModel[key] = dst + } + } + } +} + +// buildCache 从数据库构建渠道缓存。 +// 使用独立 context 避免请求取消导致空值被长期缓存。 +func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { + // 断开请求取消链,避免客户端断连导致空值被长期缓存 + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) + defer cancel() + + channels, err := s.repo.ListAll(dbCtx) + if err != nil { + // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 + slog.Warn("failed to build channel cache", "error", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL + s.cache.Store(errorCache) + return nil, fmt.Errorf("list all channels: %w", err) + } + + // 收集所有 groupID,批量查询 platform + var allGroupIDs []int64 + for i := range channels { + allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) + } + groupPlatforms := make(map[int64]string) + if len(allGroupIDs) > 0 { + groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) + if err != nil { + slog.Warn("failed to load group platforms for channel cache", "error", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) + return nil, fmt.Errorf("get group platforms: %w", err) + } + } + + cache := newEmptyChannelCache() + cache.groupPlatform = groupPlatforms + cache.byID = make(map[int64]*Channel, len(channels)) + cache.loadedAt = time.Now() + + for i := range channels { + ch := &channels[i] + cache.byID[ch.ID] = ch + + for _, gid := range ch.GroupIDs { + cache.channelByGroupID[gid] = ch + platform := groupPlatforms[gid] + expandPricingToCache(cache, ch, gid, platform) + expandMappingToCache(cache, ch, gid, platform) + } + } + + // 通配符条目保持配置顺序(最先匹配到优先) + + s.cache.Store(cache) + return cache, nil +} + +// invalidateCache 使缓存失效,让下次读取时自然重建 + +// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。 +// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。 +func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool { + return groupPlatform == pricingPlatform +} + +// matchingPlatforms 返回分组平台对应的可匹配平台列表。 +// 各平台严格独立,只返回自身。 +func matchingPlatforms(groupPlatform string) []string { + return []string{groupPlatform} +} +func (s *ChannelService) invalidateCache() { + s.cache.Store((*channelCache)(nil)) + s.cacheSF.Forget("channel_cache") + + // 主动重建缓存,确保 CRUD 后立即生效 + if _, err := s.buildCache(context.Background()); err != nil { + slog.Warn("failed to rebuild channel cache after invalidation", "error", err) + } +} + +// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先) +func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing { + gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} + wildcards := c.wildcardByGroupPlatform[gpKey] + for _, wc := range wildcards { + if strings.HasPrefix(modelLower, wc.prefix) { + return wc.pricing + } + } + return nil +} + +// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先) +func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string { + gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} + wildcards := c.wildcardMappingByGP[gpKey] + for _, wc := range wildcards { + if strings.HasPrefix(modelLower, wc.prefix) { + return wc.target + } + } + return "" +} + +// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。 +// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。 +func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if pricing, ok := cache.pricingByGroupModel[key]; ok { + return pricing + } + } + // 精确查找全部失败,依次尝试通配符匹配 + for _, p := range matchingPlatforms(groupPlatform) { + if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil { + return pricing + } + } + return nil +} + +// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。 +// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。 +func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if mapped, ok := cache.mappingByGroupModel[key]; ok { + return mapped + } + } + for _, p := range matchingPlatforms(groupPlatform) { + if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" { + return mapped + } + } + return "" +} + +// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) +func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return nil, nil + } + + return ch.Clone(), nil +} + +// channelLookup 热路径公共查找结果 +type channelLookup struct { + cache *channelCache + channel *Channel + platform string +} + +// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。 +// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。 +func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return nil, nil + } + return &channelLookup{ + cache: cache, + channel: ch, + platform: cache.groupPlatform[groupID], + }, nil +} + +// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。 +// 各平台严格独立,只在本平台内查找定价。 +func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { + lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache", "group_id", groupID, "error", err) + return nil + } + if lk == nil { + return nil + } + + modelLower := strings.ToLower(model) + pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) + if pricing == nil { + return nil + } + + cp := pricing.Clone() + return &cp +} + +// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)) +// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。 +func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err) + } + if lk == nil { + return ChannelMappingResult{MappedModel: model} + } + return resolveMapping(lk, groupID, model) +} + +// IsModelRestricted 检查模型是否被渠道限制。 +// 返回 true 表示模型被限制(不在允许列表中)。 +// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 +func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + lk, _ := s.lookupGroupChannel(ctx, groupID) + if lk == nil { + return false + } + return checkRestricted(lk, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction), +// restricted 始终返回 false,保留签名兼容性。 +func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if groupID == nil { + return ChannelMappingResult{MappedModel: model}, false + } + lk, _ := s.lookupGroupChannel(ctx, *groupID) + if lk == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return resolveMapping(lk, *groupID, model), false +} + +// resolveMapping 基于已查找的渠道信息解析模型映射。 +// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。 +func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { + result := ChannelMappingResult{ + MappedModel: model, + ChannelID: lk.channel.ID, + BillingModelSource: lk.channel.BillingModelSource, + } + if result.BillingModelSource == "" { + result.BillingModelSource = BillingModelSourceChannelMapped + } + + modelLower := strings.ToLower(model) + if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" { + result.MappedModel = mapped + result.Mapped = true + } + + return result +} + +// checkRestricted 基于已查找的渠道信息检查模型是否被限制。 +// 只在本平台的定价列表中查找。 +func checkRestricted(lk *channelLookup, groupID int64, model string) bool { + if !lk.channel.RestrictModels { + return false + } + modelLower := strings.ToLower(model) + // 使用与查找定价相同的跨平台逻辑 + if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil { + return false + } + return true +} + +// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。 +func ReplaceModelInBody(body []byte, newModel string) []byte { + if len(body) == 0 { + return body + } + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { + return body + } + newBody, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body + } + return newBody +} + +// --- CRUD --- + +// Create 创建渠道 +func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) { + exists, err := s.repo.ExistsByName(ctx, input.Name) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + + // 检查分组冲突 + if len(input.GroupIDs) > 0 { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + } + + channel := &Channel{ + Name: input.Name, + Description: input.Description, + Status: StatusActive, + BillingModelSource: input.BillingModelSource, + RestrictModels: input.RestrictModels, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, + } + if channel.BillingModelSource == "" { + channel.BillingModelSource = BillingModelSourceChannelMapped + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + return nil, err + } + + if err := s.repo.Create(ctx, channel); err != nil { + return nil, fmt.Errorf("create channel: %w", err) + } + + s.invalidateCache() + return s.repo.GetByID(ctx, channel.ID) +} + +// GetByID 获取渠道详情 +func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) { + return s.repo.GetByID(ctx, id) +} + +// Update 更新渠道 +func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) { + channel, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + channel.Name = input.Name + } + + if input.Description != nil { + channel.Description = *input.Description + } + + if input.Status != "" { + channel.Status = input.Status + } + + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + + // 检查分组冲突 + if input.GroupIDs != nil { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + channel.GroupIDs = *input.GroupIDs + } + + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + return nil, err + } + + // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 + var oldGroupIDs []int64 + if s.authCacheInvalidator != nil { + var err2 error + oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id) + if err2 != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) + } + } + + if err := s.repo.Update(ctx, channel); err != nil { + return nil, fmt.Errorf("update channel: %w", err) + } + + s.invalidateCache() + + // 失效新旧分组的 auth 缓存 + if s.authCacheInvalidator != nil { + seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) + for _, gid := range oldGroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + for _, gid := range channel.GroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + } + + return s.repo.GetByID(ctx, id) +} + +// Delete 删除渠道 +func (s *ChannelService) Delete(ctx context.Context, id int64) error { + // 先获取关联分组用于失效缓存 + groupIDs, err := s.repo.GetGroupIDs(ctx, id) + if err != nil { + slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) + } + + if err := s.repo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete channel: %w", err) + } + + s.invalidateCache() + + if s.authCacheInvalidator != nil { + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + + return nil +} + +// List 获取渠道列表 +func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + return s.repo.List(ctx, params, status, search) +} + +// modelEntry 表示一个模型模式条目(用于冲突检测) +type modelEntry struct { + pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4") + prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样) + wildcard bool +} + +// conflictsBetween 检查两个模型模式是否冲突 +func conflictsBetween(a, b modelEntry) bool { + switch { + case !a.wildcard && !b.wildcard: + return a.prefix == b.prefix + case a.wildcard && !b.wildcard: + return strings.HasPrefix(b.prefix, a.prefix) + case !a.wildcard && b.wildcard: + return strings.HasPrefix(a.prefix, b.prefix) + default: + return strings.HasPrefix(a.prefix, b.prefix) || + strings.HasPrefix(b.prefix, a.prefix) + } +} + +// toModelEntry 将模型名转换为 modelEntry +func toModelEntry(pattern string) modelEntry { + lower := strings.ToLower(pattern) + isWild := strings.HasSuffix(lower, "*") + prefix := lower + if isWild { + prefix = strings.TrimSuffix(lower, "*") + } + return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild} +} + +// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。 +// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。 +func validateNoConflictingModels(pricingList []ChannelModelPricing) error { + byPlatform := make(map[string][]modelEntry) + for _, p := range pricingList { + for _, model := range p.Models { + byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model)) + } + } + for platform, entries := range byPlatform { + if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil { + return err + } + } + return nil +} + +// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式 +func validateNoConflictingMappings(mapping map[string]map[string]string) error { + for platform, platformMapping := range mapping { + entries := make([]modelEntry, 0, len(platformMapping)) + for src := range platformMapping { + entries = append(entries, toModelEntry(src)) + } + if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil { + return err + } + } + return nil +} + +func validatePricingIntervals(pricingList []ChannelModelPricing) error { + for _, pricing := range pricingList { + if err := ValidateIntervals(pricing.Intervals); err != nil { + return infraerrors.BadRequest( + "INVALID_PRICING_INTERVALS", + fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v", + pricing.Platform, pricing.Models, err), + ) + } + } + return nil +} + +// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误 +func detectConflicts(entries []modelEntry, platform, errCode, label string) error { + for i := 0; i < len(entries); i++ { + for j := i + 1; j < len(entries); j++ { + if conflictsBetween(entries[i], entries[j]) { + return infraerrors.BadRequest(errCode, + fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range", + label, entries[i].pattern, entries[j].pattern, platform)) + } + } + } + return nil +} + +// --- Input types --- + +// CreateChannelInput 创建渠道输入 +type CreateChannelInput struct { + Name string + Description string + GroupIDs []int64 + ModelPricing []ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels bool +} + +// UpdateChannelInput 更新渠道输入 +type UpdateChannelInput struct { + Name string + Description *string + Status string + GroupIDs *[]int64 + ModelPricing *[]ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels *bool +} diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go new file mode 100644 index 00000000..3a01fd80 --- /dev/null +++ b/backend/internal/service/channel_service_test.go @@ -0,0 +1,2201 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock: ChannelRepository +// --------------------------------------------------------------------------- + +type mockChannelRepository struct { + listAllFn func(ctx context.Context) ([]Channel, error) + getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error) + createFn func(ctx context.Context, channel *Channel) error + getByIDFn func(ctx context.Context, id int64) (*Channel, error) + updateFn func(ctx context.Context, channel *Channel) error + deleteFn func(ctx context.Context, id int64) error + listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + existsByNameFn func(ctx context.Context, name string) (bool, error) + existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error) + getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error) + setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error + getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error) + getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + deleteModelPricingFn func(ctx context.Context, id int64) error + replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error { + if m.createFn != nil { + return m.createFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) { + if m.getByIDFn != nil { + return m.getByIDFn(ctx, id) + } + return nil, ErrChannelNotFound +} + +func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error { + if m.updateFn != nil { + return m.updateFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error { + if m.deleteFn != nil { + return m.deleteFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + if m.listFn != nil { + return m.listFn(ctx, params, status, search) + } + return nil, nil, nil +} + +func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) { + if m.listAllFn != nil { + return m.listAllFn(ctx) + } + return nil, nil +} + +func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + if m.existsByNameFn != nil { + return m.existsByNameFn(ctx, name) + } + return false, nil +} + +func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + if m.existsByNameExcludingFn != nil { + return m.existsByNameExcludingFn(ctx, name, excludeID) + } + return false, nil +} + +func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + if m.getGroupIDsFn != nil { + return m.getGroupIDsFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + if m.setGroupIDsFn != nil { + return m.setGroupIDsFn(ctx, channelID, groupIDs) + } + return nil +} + +func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + if m.getChannelIDByGroupIDFn != nil { + return m.getChannelIDByGroupIDFn(ctx, groupID) + } + return 0, nil +} + +func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if m.getGroupsInOtherChannelsFn != nil { + return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + if m.getGroupPlatformsFn != nil { + return m.getGroupPlatformsFn(ctx, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) { + if m.listModelPricingFn != nil { + return m.listModelPricingFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.createModelPricingFn != nil { + return m.createModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.updateModelPricingFn != nil { + return m.updateModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + if m.deleteModelPricingFn != nil { + return m.deleteModelPricingFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error { + if m.replaceModelPricingFn != nil { + return m.replaceModelPricingFn(ctx, channelID, pricingList) + } + return nil +} + +// --------------------------------------------------------------------------- +// Mock: APIKeyAuthCacheInvalidator +// --------------------------------------------------------------------------- + +type mockChannelAuthCacheInvalidator struct { + invalidatedGroupIDs []int64 + invalidatedKeys []string + invalidatedUserIDs []int64 +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) { + m.invalidatedKeys = append(m.invalidatedKeys, key) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) { + m.invalidatedGroupIDs = append(m.invalidatedGroupIDs, groupID) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newTestChannelService(repo *mockChannelRepository) *ChannelService { + return NewChannelService(repo, nil) +} + +func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { + return NewChannelService(repo, auth) +} + +// makeStandardRepo returns a repo that serves one active channel with anthropic pricing +// for group 1, with the given model pricing and model mapping. +func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository { + return &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return groupPlatforms, nil + }, + } +} + +// =========================================================================== +// 1. BuildModelMappingChain +// =========================================================================== + +func TestBuildModelMappingChain(t *testing.T) { + tests := []struct { + name string + result ChannelMappingResult + requestModel string + upstreamModel string + want string + }{ + { + name: "no mapping, no upstream diff", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4", + want: "", + }, + { + name: "no mapping, upstream differs", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream differs", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "my-model", + upstreamModel: "actual-upstream", + want: "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream", + }, + { + name: "mapped, upstream same as mapped", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream empty", + result: ChannelMappingResult{Mapped: true, MappedModel: "target-model"}, + requestModel: "my-model", + upstreamModel: "", + want: "my-model\u2192target-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.result.BuildModelMappingChain(tt.requestModel, tt.upstreamModel) + require.Equal(t, tt.want, got) + }) + } +} + +// =========================================================================== +// 2. ReplaceModelInBody +// =========================================================================== + +func TestReplaceModelInBody(t *testing.T) { + tests := []struct { + name string + body []byte + newModel string + check func(t *testing.T, result []byte) + }{ + { + name: "empty body", + body: []byte{}, + newModel: "new-model", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte{}, result) + }, + }, + { + name: "model already equal", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-sonnet-4", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), result) + }, + }, + { + name: "model different", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) + }, + }, + { + name: "no model field", + body: []byte(`{"temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ReplaceModelInBody(tt.body, tt.newModel) + tt.check(t, result) + }) + } +} + +// =========================================================================== +// 3. validateNoConflictingModels + validateNoConflictingMappings +// =========================================================================== + +func TestValidateNoConflictingModels(t *testing.T) { + tests := []struct { + name string + pricingList []ChannelModelPricing + wantErr bool + errContains string + }{ + { + name: "no duplicates", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}}, + {Platform: "openai", Models: []string{"gpt-5.1"}}, + }, + wantErr: false, + }, + { + name: "same platform duplicate", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + wantErr: true, + errContains: "claude-sonnet-4", + }, + { + name: "same model different platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"model-a"}}, + {Platform: "openai", Models: []string{"model-a"}}, + }, + wantErr: false, + }, + { + name: "case insensitive", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"Claude"}}, + {Platform: "anthropic", Models: []string{"claude"}}, + }, + wantErr: true, + }, + { + name: "empty list (nil)", + pricingList: nil, + wantErr: false, + }, + { + name: "wildcard_vs_wildcard_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "wildcard_vs_exact_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "no_conflict_different_platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "openai", Models: []string{"claude-*"}}, + }, + wantErr: false, + }, + { + name: "no_conflict_same_platform_different_prefix", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "anthropic", Models: []string{"gpt-*"}}, + }, + wantErr: false, + }, + { + name: "catch_all_wildcard_conflicts_with_everything", + pricingList: []ChannelModelPricing{ + {Platform: "openai", Models: []string{"*"}}, + {Platform: "openai", Models: []string{"gpt-5"}}, + }, + wantErr: true, + errContains: "conflict", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingModels(tt.pricingList) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } + + // Additional sub-case: explicit empty slice + t.Run("empty list (empty slice)", func(t *testing.T) { + err := validateNoConflictingModels([]ChannelModelPricing{}) + require.NoError(t, err) + }) +} + +func TestValidateNoConflictingMappings(t *testing.T) { + tests := []struct { + name string + mapping map[string]map[string]string + wantErr bool + errContains string + }{ + { + name: "nil mapping", + mapping: nil, + wantErr: false, + }, + { + name: "empty mapping", + mapping: map[string]map[string]string{}, + wantErr: false, + }, + { + name: "no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"}, + }, + wantErr: false, + }, + { + name: "wildcard vs wildcard conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a", "claude-opus-*": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "wildcard vs exact conflict", + mapping: map[string]map[string]string{ + "openai": {"gpt-*": "a", "gpt-4o": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "exact duplicate conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-4": "a"}, + "openai": {"claude-opus-4": "b"}, + }, + wantErr: false, // different platforms + }, + { + name: "different platforms no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a"}, + "openai": {"claude-*": "b"}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingMappings(tt.mapping) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestConflictsBetween(t *testing.T) { + tests := []struct { + name string + a, b modelEntry + want bool + }{ + { + name: "exact same", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact different", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "gpt-4o", wildcard: false}, + want: false, + }, + { + name: "wildcard matches exact", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact does not match unrelated wildcard", + a: modelEntry{prefix: "gpt-4o", wildcard: false}, + b: modelEntry{prefix: "claude-", wildcard: true}, + want: false, + }, + { + name: "wildcard prefix overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-", wildcard: true}, + want: true, + }, + { + name: "wildcards no overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "gpt-", wildcard: true}, + want: false, + }, + { + name: "catch-all wildcard vs any", + a: modelEntry{prefix: "", wildcard: true}, + b: modelEntry{prefix: "anything", wildcard: false}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, conflictsBetween(tt.a, tt.b)) + }) + } +} + +// =========================================================================== +// 4. Cache Building + Hot Path Methods +// =========================================================================== + +// --- 4.1 GetChannelForGroup --- + +func TestGetChannelForGroup_Success(t *testing.T) { + ch := Channel{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, int64(1), result.ID) + require.Equal(t, "test-channel", result.Name) + + // returned value should be a clone + result.Name = "mutated" + result2, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Equal(t, "test-channel", result2.Name) +} + +func TestGetChannelForGroup_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Nil(t, result) +} + +func TestGetChannelForGroup_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 999) + require.NoError(t, err) + require.Nil(t, result) +} + +func TestGetChannelForGroup_CacheError(t *testing.T) { + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, errors.New("db connection failed") + }, + } + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "db connection failed") +} + +// --- 4.2 GetChannelModelPricing --- + +func TestGetChannelModelPricing_ExactMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + require.InDelta(t, 15e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) +} + +func TestGetChannelModelPricing_WildcardMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4") + require.NotNil(t, result) + require.Equal(t, int64(200), result.ID) +} + +func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514") + require.NotNil(t, result) + // "claude-*" is defined first, so it matches first regardless of prefix length + require.Equal(t, int64(200), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_NoMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)}, + {ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic", 20: "openai"}) + svc := newTestChannelService(repo) + + // Group 10 (anthropic) should NOT see openai pricing + result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1") + require.Nil(t, result) + + // Group 10 (anthropic) should see anthropic pricing + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, int64(200), result.ID) + + // Group 20 (openai) should see openai pricing + result = svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // Group 20 (openai) should NOT see anthropic pricing + result = svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + + // Mutate the returned pricing's slice fields — original cache should not be affected + // (Clone copies slices independently, pointer fields are shared per design) + result.Models = append(result.Models, "hacked") + result.ID = 999 + + // Original cache should not be affected (slice independence + struct copy) + result2 := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result2) + require.Equal(t, 1, len(result2.Models)) + require.Equal(t, int64(100), result2.ID) +} + +// --- 4.3 ResolveChannelMapping --- + +func TestResolveChannelMapping_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Group 999 is not in any channel + result := svc.ResolveChannelMapping(context.Background(), 999, "claude-opus-4") + require.Equal(t, "claude-opus-4", result.MappedModel) + require.False(t, result.Mapped) + require.Equal(t, int64(0), result.ChannelID) +} + +func TestResolveChannelMapping_ExactMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.True(t, result.Mapped) + require.Equal(t, "claude-sonnet-4-20250514", result.MappedModel) + require.Equal(t, int64(1), result.ChannelID) +} + +func TestResolveChannelMapping_WildcardMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "*": "gpt-5.4", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "any-model-name") + require.True(t, result.Mapped) + require.Equal(t, "gpt-5.4", result.MappedModel) +} + +func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-*": "target2", + "claude-sonnet-*": "target1", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.True(t, result.Mapped) + // map iteration order is non-deterministic, so the first-match depends on + // insertion order which Go maps don't guarantee; verify that one of the + // wildcard targets matched + require.Contains(t, []string{"target1", "target2"}, result.MappedModel) +} + +func TestResolveChannelMapping_NoMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "mapped", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-opus-4", result.MappedModel) + require.Equal(t, int64(1), result.ChannelID) +} + +func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + BillingModelSource: "", // empty + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource) +} + +func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + BillingModelSource: BillingModelSourceUpstream, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.Equal(t, BillingModelSourceUpstream, result.BillingModelSource) +} + +func TestResolveChannelMapping_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "mapped", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-sonnet-4", result.MappedModel) + require.Equal(t, int64(0), result.ChannelID) // no channel +} + +// --- 4.4 IsModelRestricted --- + +func TestIsModelRestricted_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Group 999 is not in any channel + restricted := svc.IsModelRestricted(context.Background(), 999, "claude-opus-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_RestrictDisabled(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: false, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Even though model is not in pricing, RestrictModels=false + restricted := svc.IsModelRestricted(context.Background(), 10, "nonexistent-model") + require.False(t, restricted) +} + +func TestIsModelRestricted_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + RestrictModels: true, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "any-model") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelInPricing(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "claude-opus-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelInWildcard(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "claude-sonnet-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelNotFound(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "gpt-5.1") + require.True(t, restricted) +} + +func TestIsModelRestricted_CaseInsensitive(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "Claude-Opus-4") + require.False(t, restricted) +} + +// --- 4.5 ResolveChannelMappingAndRestrict --- +// 注意:模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction), +// ResolveChannelMappingAndRestrict 仅做映射,restricted 始终为 false。 + +func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) { + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4") + require.False(t, restricted) + require.False(t, mapping.Mapped) + require.Equal(t, "claude-opus-4", mapping.MappedModel) +} + +func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + gid := int64(10) + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4") + require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段 + require.True(t, mapping.Mapped) + require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel) +} + +func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + gid := int64(10) + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model") + require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段 + require.False(t, mapping.Mapped) + require.Equal(t, "unknown-model", mapping.MappedModel) +} + +// --- 4.6 Cache Building Specifics --- + +func TestBuildCache_DBError(t *testing.T) { + callCount := 0 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + callCount++ + return nil, errors.New("database down") + }, + } + svc := newTestChannelService(repo) + + // First call should fail + _, err := svc.GetChannelForGroup(context.Background(), 10) + require.Error(t, err) + require.Contains(t, err.Error(), "database down") + require.Equal(t, 1, callCount) + + // Second call within error-TTL should use error cache, but still return error + // Because buildCache stores error-TTL cache and returns error, the cached value + // is still within TTL and loadCache returns it (which is an empty cache). + // Actually, re-reading the code: buildCache returns nil, err, and the error cache + // only serves as a "don't retry immediately" mechanism. The singleflight.Do + // returns the error. On next call within error-TTL, the cache has an empty but + // valid entry, so loadCache returns it (with empty maps). GetChannelForGroup + // will find nothing and return nil, nil. + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Nil(t, result) + // Should NOT have hit DB again (error-TTL cache is active) + require.Equal(t, 1, callCount) +} + +func TestBuildCache_GroupPlatformError(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, errors.New("group platforms failed") + }, + } + svc := newTestChannelService(repo) + + // Should fail-close: error propagated when group platforms cannot be loaded + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.Error(t, err) + require.Nil(t, result) + + // Within error-TTL, second call should hit cache (empty) and return nil, nil + result2, err2 := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err2) + require.Nil(t, result2) +} + +func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20, 30}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{ + 10: "anthropic", + 20: "anthropic", + 30: "anthropic", + }) + svc := newTestChannelService(repo) + + for _, gid := range []int64{10, 20, 30} { + result := svc.GetChannelModelPricing(context.Background(), gid, "claude-opus-4") + require.NotNil(t, result, "group %d should have pricing", gid) + require.Equal(t, int64(100), result.ID) + } +} + +func TestBuildCache_PlatformFiltering(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{ + 10: "anthropic", + 20: "openai", + }) + svc := newTestChannelService(repo) + + // anthropic group sees only anthropic models + require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")) + require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")) + + // openai group sees only openai models + require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1")) + require.Nil(t, svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4")) +} + +func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + // Configuration order: shortest prefix first + {ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)}, + {ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)}, + {ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // "c-son-4-xxx" matches all three wildcards, but "c-*" (ID=100) is first in config + result := svc.GetChannelModelPricing(context.Background(), 10, "c-son-4-xxx") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // "c-son-yyy" matches "c-*" and "c-son-*", but "c-*" (ID=100) is first + result = svc.GetChannelModelPricing(context.Background(), 10, "c-son-yyy") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // "c-other" only matches "c-*" (ID=100) + result = svc.GetChannelModelPricing(context.Background(), 10, "c-other") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) +} + +// --- 4.7 invalidateCache --- + +func TestInvalidateCache(t *testing.T) { + callCount := 0 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + callCount++ + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{10: "anthropic"}, nil + }, + } + svc := newTestChannelService(repo) + + // First load + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 1, callCount) + + // Second call should use cache + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 1, callCount) // no new DB call + + // Invalidate + svc.invalidateCache() + + // Next call should rebuild from DB + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 2, callCount) // rebuilt +} + +// =========================================================================== +// 5. CRUD Methods +// =========================================================================== + +// --- 5.1 Create --- + +func TestCreate_Success(t *testing.T) { + createdID := int64(42) + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return nil, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + ch.ID = createdID + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + GroupIDs: []int64{10}, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, createdID, result.ID) +} + +func TestCreate_NameExists(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return true, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "existing-channel", + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrChannelExists) +} + +func TestCreate_GroupConflict(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return []int64{10}, nil // group 10 already in another channel + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + GroupIDs: []int64{10, 20}, + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrGroupAlreadyInChannel) +} + +func TestCreate_DuplicateModel(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, // duplicate + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "claude-opus-4") +} + +func TestCreate_InvalidPricingIntervals(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + ModelPricing: []ChannelModelPricing{ + { + Platform: "anthropic", + Models: []string{"claude-opus-4"}, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)}, + }, + }, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS") + require.Contains(t, err.Error(), "overlap") +} + +func TestCreate_DefaultBillingModelSource(t *testing.T) { + var capturedChannel *Channel + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + capturedChannel = ch + ch.ID = 1 + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return capturedChannel, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + BillingModelSource: "", // empty, should default to "channel_mapped" + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource) +} + +func TestCreate_InvalidatesCache(t *testing.T) { + loadCount := 0 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{10: "anthropic"}, nil + }, + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + createFn: func(_ context.Context, c *Channel) error { + c.ID = 2 + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "new", Status: StatusActive}, nil + }, + } + svc := newTestChannelService(repo) + + // Load cache + _ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Equal(t, 1, loadCount) + + // Create triggers cache invalidation + _, err := svc.Create(context.Background(), &CreateChannelInput{Name: "new"}) + require.NoError(t, err) + + // Next cache access should rebuild + _ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Equal(t, 2, loadCount) +} + +// --- 5.2 Update --- + +func TestUpdate_Success(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Name: "updated-name", + Description: testPtrString("new desc"), + }) + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestUpdate_NotFound(t *testing.T) { + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return nil, ErrChannelNotFound + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Update(context.Background(), 999, &UpdateChannelInput{ + Name: "whatever", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "channel") +} + +func TestUpdate_NameConflict(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + existsByNameExcludingFn: func(_ context.Context, _ string, _ int64) (bool, error) { + return true, nil // name conflicts with another channel + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Name: "conflicting-name", + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrChannelExists) +} + +func TestUpdate_GroupConflict(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return []int64{20}, nil // group 20 in another channel + }, + } + svc := newTestChannelService(repo) + + newGroupIDs := []int64{10, 20} + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + GroupIDs: &newGroupIDs, + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrGroupAlreadyInChannel) +} + +func TestUpdate_DuplicateModel(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + } + svc := newTestChannelService(repo) + + dupPricing := []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelPricing: &dupPricing, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "claude-opus-4") +} + +func TestUpdate_InvalidPricingIntervals(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + } + svc := newTestChannelService(repo) + + invalidPricing := []ChannelModelPricing{ + { + Platform: "anthropic", + Models: []string{"claude-opus-4"}, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)}, + }, + }, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelPricing: &invalidPricing, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS") + require.Contains(t, err.Error(), "unbounded") +} + +func TestUpdate_InvalidatesChannelCache(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + loadCount := 0 + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{*existing}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + // Load cache first + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 1, loadCount) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Description: testPtrString("updated"), + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Channel cache should be invalidated (next access rebuilds) + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 2, loadCount) +} + +func TestUpdate_InvalidatesAuthCache(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + auth := &mockChannelAuthCacheInvalidator{} + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelServiceWithAuth(repo, auth) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Description: testPtrString("updated"), + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Auth cache should be invalidated for both groups + require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs) +} + +// --- 5.3 Delete --- + +func TestChannelDelete_Success(t *testing.T) { + deleted := false + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + deleted = true + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + require.True(t, deleted) +} + +func TestChannelDelete_InvalidatesCaches(t *testing.T) { + auth := &mockChannelAuthCacheInvalidator{} + loadCount := 0 + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, nil + }, + } + svc := newTestChannelServiceWithAuth(repo, auth) + + // Load cache first + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 1, loadCount) + + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + + // Auth cache invalidated for both groups + require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs) + + // Channel cache invalidated + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 2, loadCount) +} + +func TestChannelDelete_NotFound(t *testing.T) { + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + return errors.New("record not found") + }, + } + svc := newTestChannelService(repo) + + err := svc.Delete(context.Background(), 999) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") +} + +// =========================================================================== +// 6. Edge Case Tests +// =========================================================================== + +// --- 6.1 Create with empty GroupIDs --- + +func TestCreate_NoGroups(t *testing.T) { + createdID := int64(55) + getGroupsInOtherChannelsCalled := false + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + getGroupsInOtherChannelsCalled = true + return nil, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + ch.ID = createdID + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "no-groups-channel", + GroupIDs: []int64{}, // empty slice + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, createdID, result.ID) + // GetGroupsInOtherChannels should NOT have been called (skipped by len(input.GroupIDs) > 0) + require.False(t, getGroupsInOtherChannelsCalled) +} + +// --- 6.2 Update only Status --- + +func TestUpdate_StatusOnly(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + } + var capturedChannel *Channel + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, ch *Channel) error { + capturedChannel = ch + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Status: StatusDisabled, + }) + require.NoError(t, err) + require.NotNil(t, result) + // Verify that the channel passed to repo.Update has the new status + require.NotNil(t, capturedChannel) + require.Equal(t, StatusDisabled, capturedChannel.Status) + // Name should remain unchanged + require.Equal(t, "test-channel", capturedChannel.Name) +} + +// --- 6.3 Delete when GetGroupIDs fails --- + +func TestChannelDelete_GetGroupIDsError(t *testing.T) { + deleted := false + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, errors.New("group IDs lookup failed") + }, + deleteFn: func(_ context.Context, _ int64) error { + deleted = true + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + // Delete should still succeed even though GetGroupIDs returned error (degradation path L588-591) + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + require.True(t, deleted) +} + +// --- 6.4 ReplaceModelInBody with invalid JSON --- + +func TestReplaceModelInBody_InvalidJSON(t *testing.T) { + // Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set + // (no panic, no error from sjson, but result is mutated garbage) + brokenBody := []byte("{broken") + result := ReplaceModelInBody(brokenBody, "new-model") + require.NotNil(t, result) + // sjson does not error on this input, so result differs from original — just verify no panic + + // Case 2: JSON array — sjson.SetBytes returns error on non-object, + // triggering the L447 error fallback path that returns original body. + arrayBody := []byte("[]") + result2 := ReplaceModelInBody(arrayBody, "new-model") + require.Equal(t, arrayBody, result2) +} + +// =========================================================================== +// 7. isPlatformPricingMatch +// =========================================================================== + +func TestIsPlatformPricingMatch(t *testing.T) { + tests := []struct { + name string + groupPlatform string + pricingPlatform string + want bool + }{ + {"antigravity does NOT match anthropic", PlatformAntigravity, PlatformAnthropic, false}, + {"antigravity does NOT match gemini", PlatformAntigravity, PlatformGemini, false}, + {"antigravity matches antigravity", PlatformAntigravity, PlatformAntigravity, true}, + {"antigravity does NOT match openai", PlatformAntigravity, PlatformOpenAI, false}, + {"anthropic matches anthropic", PlatformAnthropic, PlatformAnthropic, true}, + {"anthropic does NOT match antigravity", PlatformAnthropic, PlatformAntigravity, false}, + {"anthropic does NOT match gemini", PlatformAnthropic, PlatformGemini, false}, + {"gemini matches gemini", PlatformGemini, PlatformGemini, true}, + {"gemini does NOT match antigravity", PlatformGemini, PlatformAntigravity, false}, + {"gemini does NOT match anthropic", PlatformGemini, PlatformAnthropic, false}, + {"empty string matches nothing", "", PlatformAnthropic, false}, + {"empty string matches empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isPlatformPricingMatch(tt.groupPlatform, tt.pricingPlatform)) + }) + } +} + +// =========================================================================== +// 8. matchingPlatforms +// =========================================================================== + +func TestMatchingPlatforms(t *testing.T) { + tests := []struct { + name string + groupPlatform string + want []string + }{ + {"antigravity returns itself only", PlatformAntigravity, []string{PlatformAntigravity}}, + {"anthropic returns itself", PlatformAnthropic, []string{PlatformAnthropic}}, + {"gemini returns itself", PlatformGemini, []string{PlatformGemini}}, + {"openai returns itself", PlatformOpenAI, []string{PlatformOpenAI}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchingPlatforms(tt.groupPlatform) + require.Equal(t, tt.want, result) + }) + } +} + +// =========================================================================== +// 9. Antigravity platform isolation — no cross-platform pricing leakage +// =========================================================================== + +func TestGetChannelModelPricing_AntigravityDoesNotSeeCrossPlatformPricing(t *testing.T) { + // Channel has anthropic pricing for claude-opus-4-6. + // Group 10 is antigravity — should NOT see the anthropic pricing. + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: PlatformAnthropic, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6") + require.Nil(t, result, "antigravity group should NOT see anthropic-platform pricing") +} + +func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) { + // Channel has antigravity-platform pricing for claude-opus-4-6. + // Group 10 is anthropic — should NOT see antigravity pricing (no cross-platform leakage). + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: PlatformAntigravity, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6") + require.Nil(t, result, "anthropic group should NOT see antigravity-platform pricing") +} + +// =========================================================================== +// 10. Antigravity platform isolation — no cross-platform model mapping +// =========================================================================== + +func TestResolveChannelMapping_AntigravityDoesNotSeeCrossPlatformMapping(t *testing.T) { + // Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6. + // Group 10 is antigravity — should NOT apply the anthropic mapping. + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: { + "claude-opus-4-5": "claude-opus-4-6", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4-5") + require.False(t, result.Mapped, "antigravity group should NOT apply anthropic mapping") + require.Equal(t, "claude-opus-4-5", result.MappedModel) +} + +// =========================================================================== +// 11. Antigravity platform isolation — same-name model across platforms +// =========================================================================== + +func TestGetChannelModelPricing_AntigravityDoesNotSeeSameModelFromOtherPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。 + // antigravity 分组不应看到任何一个(各平台严格独立)。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.Nil(t, result, "antigravity group should NOT see anthropic/gemini-platform pricing") +} + +func TestGetChannelModelPricing_AntigravityDoesNotSeeGeminiOnlyPricing(t *testing.T) { + // 只有 gemini 平台定义了模型 "gemini-model"。 + // antigravity 分组不应看到 gemini 的定价。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model") + require.Nil(t, result, "antigravity group should NOT see gemini-platform pricing") +} + +func TestGetChannelModelPricing_AntigravityDoesNotSeeWildcardFromOtherPlatforms(t *testing.T) { + // anthropic 和 gemini 都有 "shared-*" 通配符定价。 + // antigravity 分组不应命中任何一个。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.Nil(t, result, "antigravity group should NOT see wildcard pricing from other platforms") +} + +func TestResolveChannelMapping_AntigravityDoesNotSeeMappingFromOtherPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。 + // antigravity 分组不应命中任何一个。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: {"alias": "anthropic-target"}, + PlatformGemini: {"alias": "gemini-target"}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "alias") + require.False(t, result.Mapped, "antigravity group should NOT see mapping from other platforms") + require.Equal(t, "alias", result.MappedModel) +} + +func TestCheckRestricted_AntigravityDoesNotSeeModelsFromOtherPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model"。 + // antigravity 分组启用了 RestrictModels,"shared-model" 应被限制(各平台独立)。 + ch := Channel{ + ID: 1, + Status: StatusActive, + RestrictModels: true, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model") + require.True(t, restricted, "shared-model from other platforms should be restricted for antigravity") + + restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model") + require.True(t, restricted, "unknown-model should be restricted for antigravity") +} + +func TestGetChannelModelPricing_AntigravityOwnPricingWorks(t *testing.T) { + // antigravity 平台自己配置的定价应正常生效(覆盖 Claude 和 Gemini 模型)。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 600, Platform: PlatformAntigravity, Models: []string{"claude-*"}, InputPrice: testPtrFloat64(15e-6)}, + {ID: 601, Platform: PlatformAntigravity, Models: []string{"gemini-*"}, InputPrice: testPtrFloat64(2e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + // Claude 模型匹配 antigravity 定价 + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4") + require.NotNil(t, result) + require.Equal(t, int64(600), result.ID) + require.InDelta(t, 15e-6, *result.InputPrice, 1e-12) + + // Gemini 模型匹配 antigravity 定价 + result = svc.GetChannelModelPricing(context.Background(), 10, "gemini-2.5-flash") + require.NotNil(t, result) + require.Equal(t, int64(601), result.ID) + require.InDelta(t, 2e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { + // 确保非 antigravity 平台的行为不受影响。 + // anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini}) + svc := newTestChannelService(repo) + + // anthropic 分组应该只看到 anthropic 的定价 + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(600), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) + + // gemini 分组应该只看到 gemini 的定价 + result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(601), result.ID) + require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) +} diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go new file mode 100644 index 00000000..deac64d6 --- /dev/null +++ b/backend/internal/service/channel_test.go @@ -0,0 +1,435 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetModelPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)}, + {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, + }, + } + + tests := []struct { + name string + model string + wantID int64 + wantNil bool + }{ + {"exact match", "claude-sonnet-4", 1, false}, + {"case insensitive", "Claude-Sonnet-4", 1, false}, + {"not found", "gemini-3.1-pro", 0, true}, + {"wildcard pattern not matched", "claude-opus-4-20250514", 0, true}, + {"per_request model", "gpt-5.1", 3, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ch.GetModelPricing(tt.model) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantID, result.ID) + }) + } +} + +func TestGetModelPricing_ReturnsCopy(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)}, + }, + } + + result := ch.GetModelPricing("claude-sonnet-4") + require.NotNil(t, result) + + // Modify the returned copy's slice — original should be unchanged + result.Models = append(result.Models, "hacked") + + // Original should be unchanged + require.Equal(t, 1, len(ch.ModelPricing[0].Models)) +} + +func TestGetModelPricing_EmptyPricing(t *testing.T) { + ch := &Channel{ModelPricing: nil} + require.Nil(t, ch.GetModelPricing("any-model")) + + ch2 := &Channel{ModelPricing: []ChannelModelPricing{}} + require.Nil(t, ch2.GetModelPricing("any-model")) +} + +func TestGetIntervalForContext(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + } + + tests := []struct { + name string + tokens int + wantPrice *float64 + wantNil bool + }{ + {"first interval", 50000, testPtrFloat64(1e-6), false}, + // (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个 + {"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false}, + // 128001 > 128000,匹配第二个区间 + {"boundary: just above first max", 128001, testPtrFloat64(2e-6), false}, + {"unbounded interval", 500000, testPtrFloat64(2e-6), false}, + // (0, max] — 0 不匹配任何区间(左开) + {"zero tokens: no match", 0, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetIntervalForContext(tt.tokens) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12) + }) + } +} + +func TestGetIntervalForContext_NoMatch(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: testPtrInt(50000)}, + }, + } + require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min + require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open) + require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed) + require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000 +} + +func TestGetIntervalForContext_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetIntervalForContext(1000)) +} + +func TestGetTierByLabel(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)}, + }, + } + + tests := []struct { + name string + label string + wantNil bool + want float64 + }{ + {"exact match", "1K", false, 0.04}, + {"case insensitive", "hd", false, 0.12}, + {"not found", "4K", true, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetTierByLabel(tt.label) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12) + }) + } +} + +func TestGetTierByLabel_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetTierByLabel("1K")) +} + +func TestChannelClone(t *testing.T) { + original := &Channel{ + ID: 1, + Name: "test", + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"model-a"}, + InputPrice: testPtrFloat64(5e-6), + }, + }, + } + + cloned := original.Clone() + require.NotNil(t, cloned) + require.Equal(t, original.ID, cloned.ID) + require.Equal(t, original.Name, cloned.Name) + + // Modify clone slices — original should not change + cloned.GroupIDs[0] = 999 + require.Equal(t, int64(10), original.GroupIDs[0]) + + cloned.ModelPricing[0].Models[0] = "hacked" + require.Equal(t, "model-a", original.ModelPricing[0].Models[0]) +} + +func TestChannelClone_Nil(t *testing.T) { + var ch *Channel + require.Nil(t, ch.Clone()) +} + +func TestChannelModelPricingClone(t *testing.T) { + original := ChannelModelPricing{ + Models: []string{"a", "b"}, + Intervals: []PricingInterval{ + {MinTokens: 0, TierLabel: "tier1"}, + }, + } + + cloned := original.Clone() + + // Modify clone slices — original unchanged + cloned.Models[0] = "hacked" + require.Equal(t, "a", original.Models[0]) + + cloned.Intervals[0].TierLabel = "hacked" + require.Equal(t, "tier1", original.Intervals[0].TierLabel) +} + +// --- BillingMode.IsValid --- + +func TestBillingModeIsValid(t *testing.T) { + tests := []struct { + name string + mode BillingMode + want bool + }{ + {"token", BillingModeToken, true}, + {"per_request", BillingModePerRequest, true}, + {"image", BillingModeImage, true}, + {"empty", BillingMode(""), true}, + {"unknown", BillingMode("unknown"), false}, + {"random", BillingMode("xyz"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.mode.IsValid()) + }) + } +} + +// --- Channel.IsActive --- + +func TestChannelIsActive(t *testing.T) { + tests := []struct { + name string + status string + want bool + }{ + {"active", StatusActive, true}, + {"disabled", "disabled", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := &Channel{Status: tt.status} + require.Equal(t, tt.want, ch.IsActive()) + }) + } +} + +// --- ChannelModelPricing.Clone edge cases --- + +func TestChannelModelPricingClone_EdgeCases(t *testing.T) { + t.Run("nil models", func(t *testing.T) { + original := ChannelModelPricing{Models: nil} + cloned := original.Clone() + require.Nil(t, cloned.Models) + }) + + t.Run("nil intervals", func(t *testing.T) { + original := ChannelModelPricing{Intervals: nil} + cloned := original.Clone() + require.Nil(t, cloned.Intervals) + }) + + t.Run("empty models", func(t *testing.T) { + original := ChannelModelPricing{Models: []string{}} + cloned := original.Clone() + require.NotNil(t, cloned.Models) + require.Empty(t, cloned.Models) + }) +} + +// --- Channel.Clone edge cases --- + +func TestChannelClone_EdgeCases(t *testing.T) { + t.Run("nil model mapping", func(t *testing.T) { + original := &Channel{ID: 1, ModelMapping: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelMapping) + }) + + t.Run("nil model pricing", func(t *testing.T) { + original := &Channel{ID: 1, ModelPricing: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelPricing) + }) + + t.Run("deep copy model mapping", func(t *testing.T) { + original := &Channel{ + ID: 1, + ModelMapping: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + } + cloned := original.Clone() + + // Modify the cloned nested map + cloned.ModelMapping["openai"]["gpt-4"] = "hacked" + + // Original must remain unchanged + require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"]) + }) +} + +// --- ValidateIntervals --- + +func TestValidateIntervals_Empty(t *testing.T) { + require.NoError(t, ValidateIntervals(nil)) + require.NoError(t, ValidateIntervals([]PricingInterval{})) +} + +func TestValidateIntervals_ValidIntervals(t *testing.T) { + tests := []struct { + name string + intervals []PricingInterval + }{ + { + name: "single bounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "two intervals with gap", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "two contiguous intervals", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "unsorted input (auto-sorted by validator)", + intervals: []PricingInterval{ + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "single unbounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, ValidateIntervals(tt.intervals)) + }) + } +} + +func TestValidateIntervals_NegativeMinTokens(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "min_tokens") + require.Contains(t, err.Error(), ">= 0") +} + +func TestValidateIntervals_MaxTokensZero(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> 0") +} + +func TestValidateIntervals_MaxLessThanMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_MaxEqualsMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_NegativePrice(t *testing.T) { + negPrice := -0.01 + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "input_price") + require.Contains(t, err.Error(), ">= 0") +} + +func TestValidateIntervals_OverlappingIntervals(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "overlap") +} + +func TestValidateIntervals_UnboundedNotLast(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "unbounded") + require.Contains(t, err.Error(), "last") +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ecac0db0..52df52d6 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,7 +24,6 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity - PlatformSora = domain.PlatformSora ) // Account type constants @@ -107,7 +106,6 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 - SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -199,27 +197,6 @@ const ( // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. SettingKeyBetaPolicySettings = "beta_policy_settings" - // ========================= - // Sora S3 存储配置 - // ========================= - - SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 - SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 - SettingKeySoraS3Region = "sora_s3_region" // S3 区域 - SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 - SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID - SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) - SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 - SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) - SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) - SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) - - // ========================= - // Sora 用户存储配额 - // ========================= - - SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) - // ========================= // Claude Code Version Check // ========================= diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 6e19db32..e7661aad 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -761,7 +761,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock( system := gjson.GetBytes(upstream.lastBody, "system") require.True(t, system.Exists()) - require.Contains(t, system.Raw, "x-anthropic-billing-header keep") + require.Equal(t, claudeCodeSystemPrompt, system.String()) + + // 原始 system prompt 应迁移至 messages 中 + messages := gjson.GetBytes(upstream.lastBody, "messages") + require.True(t, messages.IsArray()) + firstMsg := messages.Array()[0] + require.Equal(t, "user", firstMsg.Get("role").String()) + require.Contains(t, firstMsg.Get("content.0.text").String(), "x-anthropic-billing-header keep") }) } } diff --git a/backend/internal/service/gateway_channel_restriction_fallback_test.go b/backend/internal/service/gateway_channel_restriction_fallback_test.go new file mode 100644 index 00000000..d3196419 --- /dev/null +++ b/backend/internal/service/gateway_channel_restriction_fallback_test.go @@ -0,0 +1,130 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(1), account.ID) +} + +func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID) +} diff --git a/backend/internal/service/gateway_channel_restriction_test.go b/backend/internal/service/gateway_channel_restriction_test.go new file mode 100644 index 00000000..3a2ad2ff --- /dev/null +++ b/backend/internal/service/gateway_channel_restriction_test.go @@ -0,0 +1,293 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- billingModelForRestriction --- + +func TestBillingModelForRestriction_Requested(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-5", got) +} + +func TestBillingModelForRestriction_ChannelMapped(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got) +} + +func TestBillingModelForRestriction_Upstream(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "", got, "upstream should return empty (per-account check needed)") +} + +func TestBillingModelForRestriction_Empty(t *testing.T) { + t.Parallel() + got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped") +} + +// --- resolveAccountUpstreamModel --- + +func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAntigravity, + } + // Antigravity 平台使用 DefaultAntigravityModelMapping + got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got) +} + +func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAntigravity, + } + got := resolveAccountUpstreamModel(account, "totally-unknown-model") + require.Equal(t, "", got, "unsupported model should return empty") +} + +func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAnthropic, + } + got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough") +} + +// --- checkChannelPricingRestriction --- + +func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) { + t.Parallel() + svc := &GatewayService{channelService: &ChannelService{}} + require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4")) +} + +func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) { + t.Parallel() + svc := &GatewayService{} + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4")) +} + +func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) { + t.Parallel() + svc := &GatewayService{channelService: &ChannelService{}} + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "")) +} + +func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) { + t.Parallel() + // 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,但定价列表只有 claude-opus-4-6 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "mapped model claude-sonnet-4-6 is NOT in pricing → restricted") +} + +func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) { + t.Parallel() + // 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,定价列表包含 claude-sonnet-4-6 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "mapped model claude-sonnet-4-6 IS in pricing → allowed") +} + +func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) { + t.Parallel() + // billing_model_source=requested,定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceRequested, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "requested model claude-sonnet-4-5 is NOT in pricing → restricted") +} + +func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceRequested, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "requested model IS in pricing → allowed") +} + +func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) { + t.Parallel() + // upstream 模式:预检查始终跳过(返回 false),需逐账号检查 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"), + "upstream mode should skip pre-check (per-account check needed)") +} + +func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: false, // 未开启模型限制 + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"), + "RestrictModels=false → always allowed") +} + +func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) { + t.Parallel() + // 分组没有关联渠道 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil }, + } + channelSvc := newTestChannelService(repo) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(999) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"), + "no channel for group → allowed") +} + +// --- isUpstreamModelRestrictedByChannel --- + +func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + // claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6 + // 但定价列表只有 claude-opus-4-6 + require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"), + "upstream model claude-sonnet-4-6 NOT in pricing → restricted") +} + +func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"), + "upstream model claude-sonnet-4-6 IS in pricing → allowed") +} + +func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + // totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空 + require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"), + "unmappable model → upstream model empty → not restricted (account filter handles this)") +} diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go index 161c4ba4..e5bf49b8 100644 --- a/backend/internal/service/gateway_hotpath_optimization_test.go +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { modelsListCacheTTL: time.Minute, } - result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index f28912bb..72832837 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // No concurrency service } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // legacy path } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2148,7 +2148,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { } excludedIDs := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.Error(t, err) require.Nil(t, result) require.ErrorIs(t, err, ErrNoAvailableAccounts) @@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2934,7 +2934,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { } excluded := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.Error(t, err) require.Nil(t, result) require.ErrorIs(t, err, ErrClaudeCodeOnly) @@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -3139,7 +3139,7 @@ func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 0, groupRepo.getByIDLiteCalls) } @@ -3182,7 +3182,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } @@ -3252,7 +3252,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) require.NoError(t, err) require.NotNil(t, account) - require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDCalls) // +1 for require_privacy_set check require.Equal(t, 1, groupRepo.getByIDLiteCalls) } diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index 356536b0..d0f5a8c0 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -278,3 +278,141 @@ func TestInjectClaudeCodePrompt(t *testing.T) { }) } } + +func TestRewriteSystemForNonClaudeCode(t *testing.T) { + tests := []struct { + name string + body string + system any + wantSystemStr string // system 应为纯字符串 + wantMessagesLen int // messages 数组长度 + wantFirstMsgRole string // 第一条消息的 role + wantFirstMsgText string // 第一条消息的 content[0].text + wantAckMsgText string // 第二条消息的 content[0].text + }{ + { + name: "nil system - no messages injected", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: nil, + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 1, // 原始 1 条消息,不注入 + }, + { + name: "empty string system - no messages injected", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: "", + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 1, + }, + { + name: "custom string system - migrated to messages", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: "You are a personal assistant running inside OpenClaw.", + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 3, // instruction + ack + original + wantFirstMsgRole: "user", + wantFirstMsgText: "[System Instructions]\nYou are a personal assistant running inside OpenClaw.", + wantAckMsgText: "Understood. I will follow these instructions.", + }, + { + name: "system equals Claude Code prompt - no messages injected", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: claudeCodeSystemPrompt, + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 1, + }, + { + name: "array system with custom blocks - text joined and migrated", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: []any{ + map[string]any{"type": "text", "text": "First instruction"}, + map[string]any{"type": "text", "text": "Second instruction"}, + }, + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 3, + wantFirstMsgRole: "user", + wantFirstMsgText: "[System Instructions]\nFirst instruction\n\nSecond instruction", + wantAckMsgText: "Understood. I will follow these instructions.", + }, + { + name: "empty array system - no messages injected", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: []any{}, + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 1, + }, + { + name: "json.RawMessage string system", + body: `{"model":"claude-3","system":"Custom prompt","messages":[{"role":"user","content":"hello"}]}`, + system: json.RawMessage(`"Custom prompt"`), + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 3, + wantFirstMsgRole: "user", + wantFirstMsgText: "[System Instructions]\nCustom prompt", + wantAckMsgText: "Understood. I will follow these instructions.", + }, + { + name: "json.RawMessage nil system", + body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, + system: json.RawMessage(nil), + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 1, + }, + { + name: "multiple original messages preserved", + body: `{"model":"claude-3","messages":[{"role":"user","content":"msg1"},{"role":"assistant","content":"resp1"},{"role":"user","content":"msg2"}]}`, + system: "Be helpful", + wantSystemStr: claudeCodeSystemPrompt, + wantMessagesLen: 5, // 2 injected + 3 original + wantFirstMsgRole: "user", + wantFirstMsgText: "[System Instructions]\nBe helpful", + wantAckMsgText: "Understood. I will follow these instructions.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := rewriteSystemForNonClaudeCode([]byte(tt.body), tt.system) + + var parsed map[string]any + err := json.Unmarshal(result, &parsed) + require.NoError(t, err) + + // system 应为纯字符串 + systemVal, ok := parsed["system"].(string) + require.True(t, ok, "system should be a string, got %T", parsed["system"]) + require.Equal(t, tt.wantSystemStr, systemVal) + + // 检查 messages + messages, ok := parsed["messages"].([]any) + require.True(t, ok, "messages should be an array") + require.Len(t, messages, tt.wantMessagesLen) + + if tt.wantFirstMsgRole != "" && len(messages) >= 2 { + // 检查注入的 instruction 消息 + firstMsg, ok := messages[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantFirstMsgRole, firstMsg["role"]) + + firstContent, ok := firstMsg["content"].([]any) + require.True(t, ok) + require.Len(t, firstContent, 1) + firstBlock, ok := firstContent[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantFirstMsgText, firstBlock["text"]) + + // 检查注入的 ack 消息 + ackMsg, ok := messages[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "assistant", ackMsg["role"]) + + ackContent, ok := ackMsg["content"].([]any) + require.True(t, ok) + require.Len(t, ackContent, 1) + ackBlock, ok := ackContent[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantAckMsgText, ackBlock["text"]) + } + }) + } +} diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 48488dc8..97703a9d 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d0f99639..e59412eb 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -484,6 +484,7 @@ type ClaudeUsage struct { CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // ForwardResult 转发结果 @@ -504,9 +505,6 @@ type ForwardResult struct { ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" - // Sora 媒体字段 - MediaType string // image / video / prompt - MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -569,6 +567,8 @@ type GatewayService struct { responseHeaderFilter *responseheaders.CompiledHeaderFilter debugModelRouting atomic.Bool debugClaudeMimic atomic.Bool + channelService *ChannelService + resolver *ModelPricingResolver debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService } @@ -598,6 +598,8 @@ func NewGatewayService( digestStore *DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + channelService *ChannelService, + resolver *ModelPricingResolver, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -630,6 +632,8 @@ func NewGatewayService( modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), tlsFPProfileService: tlsFPProfileService, + channelService: channelService, + resolver: resolver, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -867,17 +871,7 @@ type anthropicMetadataPayload struct { // replaceModelInBody 替换请求体中的model字段 // 优先使用定点修改,尽量保持客户端原始字段顺序。 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - if len(body) == 0 { - return body - } - if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { - return body - } - newBody, err := sjson.SetBytes(body, "model", newModel) - if err != nil { - return body - } - return newBody + return ReplaceModelInBody(body, newModel) } type claudeOAuthNormalizeOptions struct { @@ -1180,6 +1174,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { @@ -1192,8 +1195,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { +// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 +// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID +// sub2apiUserID: 系统用户 ID,用于二维亲和调度 +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1214,6 +1219,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -1412,19 +1426,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if s.isAccountSchedulableForSelection(stickyAccount) && + var stickyCacheMissReason string + + gatePass := s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) + + if rpmPass { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 + stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { @@ -1438,27 +1457,49 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - // 会话限制已满,继续到负载感知选择 + if stickyCacheMissReason == "" { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + stickyCacheMissReason = "session_limit" + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + stickyCacheMissReason = "wait_queue_full" } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } else if !gatePass { + stickyCacheMissReason = "gate_check" + } else { + stickyCacheMissReason = "rpm_red" + } + + // 记录粘性缓存未命中的结构化日志 + if stickyCacheMissReason != "" { + baseRPM := stickyAccount.GetBaseRPM() + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { + currentRPM = count + } + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", + stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", + stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -1914,9 +1955,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - if platform == PlatformSora { - return s.listSoraSchedulableAccounts(ctx, groupID) - } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2013,53 +2051,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } -func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { - const useMixed = false - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } else if groupID != nil { - accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "error", err) - return nil, useMixed, err - } - - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform != PlatformSora { - continue - } - if !s.isSoraAccountSchedulable(&acc) { - continue - } - filtered = append(filtered, acc) - } - slog.Debug("account_scheduling_list_sora", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "raw_count", len(accounts), - "filtered_count", len(filtered)) - for _, acc := range filtered { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) - } - return filtered, useMixed, nil -} - // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2084,33 +2075,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } -func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { - return s.soraUnschedulableReason(account) == "" -} - -func (s *GatewayService) soraUnschedulableReason(account *Account) string { - if account == nil { - return "account_nil" - } - if account.Status != StatusActive { - return fmt.Sprintf("status=%s", account.Status) - } - if !account.Schedulable { - return "schedulable=false" - } - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) - } - return "" -} - func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } - if account.Platform == PlatformSora { - return s.isSoraAccountSchedulable(account) - } return account.IsSchedulable() } @@ -2118,12 +2086,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { return false } - if account.Platform == PlatformSora { - if !s.isSoraAccountSchedulable(account) { - return false - } - return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 - } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -2738,6 +2700,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2809,6 +2777,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2900,6 +2874,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, + // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -2911,9 +2888,18 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -2974,6 +2960,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -3041,6 +3033,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3134,6 +3132,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -3145,6 +3145,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3152,6 +3158,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -3253,9 +3262,6 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) - if platform == PlatformSora { - s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) - } return stats } @@ -3312,11 +3318,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - detail := "generic_unschedulable" - if acc.Platform == PlatformSora { - detail = s.soraUnschedulableReason(acc) - } - return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3340,57 +3342,6 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -func (s *GatewayService) logSoraSelectionFailureDetails( - ctx context.Context, - groupID *int64, - sessionHash string, - requestedModel string, - accounts []Account, - excludedIDs map[int64]struct{}, - allowMixedScheduling bool, -) { - const maxLines = 30 - logged := 0 - - for i := range accounts { - if logged >= maxLines { - break - } - acc := &accounts[i] - diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) - if diagnosis.Category == "eligible" { - continue - } - detail := diagnosis.Detail - if detail == "" { - detail = "-" - } - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - acc.ID, - acc.Platform, - diagnosis.Category, - detail, - ) - logged++ - } - if len(accounts) > maxLines { - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - len(accounts), - logged, - ) - } -} - func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3469,9 +3420,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } - if account.Platform == PlatformSora { - return s.isSoraModelSupportedByAccount(account, requestedModel) - } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok @@ -3484,143 +3432,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } -func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { - if account == nil { - return false - } - if strings.TrimSpace(requestedModel) == "" { - return true - } - - // 先走原始精确/通配符匹配。 - mapping := account.GetModelMapping() - if len(mapping) == 0 || account.IsModelSupported(requestedModel) { - return true - } - - aliases := buildSoraModelAliases(requestedModel) - if len(aliases) == 0 { - return false - } - - hasSoraSelector := false - for pattern := range mapping { - if !isSoraModelSelector(pattern) { - continue - } - hasSoraSelector = true - if matchPatternAnyAlias(pattern, aliases) { - return true - } - } - - // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), - // 此时不应误拦截 Sora 模型请求。 - if !hasSoraSelector { - return true - } - - return false -} - -func matchPatternAnyAlias(pattern string, aliases []string) bool { - normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) - if normalizedPattern == "" { - return false - } - for _, alias := range aliases { - if matchWildcard(normalizedPattern, alias) { - return true - } - } - return false -} - -func isSoraModelSelector(pattern string) bool { - p := strings.ToLower(strings.TrimSpace(pattern)) - if p == "" { - return false - } - - switch { - case strings.HasPrefix(p, "sora"), - strings.HasPrefix(p, "gpt-image"), - strings.HasPrefix(p, "prompt-enhance"), - strings.HasPrefix(p, "sy_"): - return true - } - - return p == "video" || p == "image" -} - -func buildSoraModelAliases(requestedModel string) []string { - modelID := strings.ToLower(strings.TrimSpace(requestedModel)) - if modelID == "" { - return nil - } - - aliases := make([]string, 0, 8) - addAlias := func(value string) { - v := strings.ToLower(strings.TrimSpace(value)) - if v == "" { - return - } - for _, existing := range aliases { - if existing == v { - return - } - } - aliases = append(aliases, v) - } - - addAlias(modelID) - cfg, ok := GetSoraModelConfig(modelID) - if ok { - addAlias(cfg.Model) - switch cfg.Type { - case "video": - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case "image": - addAlias("image") - addAlias("gpt-image") - case "prompt_enhance": - addAlias("prompt-enhance") - } - return aliases - } - - switch { - case strings.HasPrefix(modelID, "sora"): - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case strings.HasPrefix(modelID, "gpt-image"): - addAlias("image") - addAlias("gpt-image") - case strings.HasPrefix(modelID, "prompt-enhance"): - addAlias("prompt-enhance") - default: - return nil - } - - return aliases -} - -func soraVideoFamilyAlias(modelID string) string { - switch { - case strings.HasPrefix(modelID, "sora2pro-hd"): - return "sora2pro-hd" - case strings.HasPrefix(modelID, "sora2pro"): - return "sora2pro" - case strings.HasPrefix(modelID, "sora2"): - return "sora2" - default: - return "" - } -} - // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -3897,6 +3708,77 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } +// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, +// system 字段仅保留 Claude Code 标识提示词。 +// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 +// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 +// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 +func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { + system = normalizeSystemParam(system) + + // 1. 提取原始 system prompt 文本 + var originalSystemText string + switch v := system.(type) { + case string: + originalSystemText = strings.TrimSpace(v) + case []any: + var parts []string + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + } + originalSystemText = strings.Join(parts, "\n\n") + } + + // 2. 将 system 替换为 Claude Code 标准提示词(纯字符串,通过 Anthropic 检测) + out, ok := setJSONValueBytes(body, "system", claudeCodeSystemPrompt) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") + return body + } + + // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 + // 模型仍通过 messages 接收完整指令,保留客户端功能 + ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) + if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { + instrMsg, err1 := json.Marshal(map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, + }, + }) + ackMsg, err2 := json.Marshal(map[string]any{ + "role": "assistant", + "content": []map[string]any{ + {"type": "text", "text": "Understood. I will follow these instructions."}, + }, + }) + if err1 != nil || err2 != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") + return out + } + + // 重建 messages 数组:[instruction, ack, ...originalMessages] + items := [][]byte{instrMsg, ackMsg} + messagesResult := gjson.GetBytes(out, "messages") + if messagesResult.IsArray() { + messagesResult.ForEach(func(_, msg gjson.Result) bool { + items = append(items, []byte(msg.Raw)) + return true + }) + } + + if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { + out = next + } + } + + return out +} + type cacheControlPath struct { path string log string @@ -4058,7 +3940,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { - policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) if policy.blockErr != nil { return nil, policy.blockErr } @@ -4088,11 +3970,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) + body = rewriteSystemForNonClaudeCode(body, parsed.System) } normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} @@ -5755,7 +5637,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // Build effective drop set: merge static defaults with dynamic beta policy filter rules - policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) effectiveDropSet := mergeDropSets(policyFilterSet) effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) @@ -6019,7 +5901,7 @@ type betaPolicyResult struct { } // evaluateBetaPolicy loads settings once and evaluates all rules against the given request. -func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { if s.settingService == nil { return betaPolicyResult{} } @@ -6034,10 +5916,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } - switch rule.Action { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + switch effectiveAction { case BetaPolicyActionBlock: if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -6079,7 +5962,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" // In the /v1/messages path, Forward() evaluates the policy first and caches the result; // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // evaluates on demand (one DB call). -func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { if c != nil { if v, ok := c.Get(betaPolicyFilterSetKey); ok { if fs, ok := v.(map[string]struct{}); ok { @@ -6087,7 +5970,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } } } - return s.evaluateBetaPolicy(ctx, "", account).filterSet + return s.evaluateBetaPolicy(ctx, "", account, model).filterSet } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. @@ -6106,6 +5989,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { } } +// matchModelWhitelist checks if a model matches any pattern in the whitelist. +// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. +func matchModelWhitelist(model string, whitelist []string) bool { + for _, pattern := range whitelist { + if matchModelPattern(pattern, model) { + return true + } + } + return false +} + +// resolveRuleAction determines the effective action and error message for a rule given the request model. +// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. +// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. +func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { + if len(rule.ModelWhitelist) == 0 { + return rule.Action, rule.ErrorMessage + } + if matchModelWhitelist(model, rule.ModelWhitelist) { + return rule.Action, rule.ErrorMessage + } + if rule.FallbackAction != "" { + return rule.FallbackAction, rule.FallbackErrorMessage + } + return BetaPolicyActionPass, "" // default fallback: pass (fail-open) +} + // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -6152,7 +6062,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( modelID string, ) ([]string, error) { // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) - policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) if policy.blockErr != nil { return nil, policy.blockErr } @@ -6164,7 +6074,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 如果不做此检查,block 规则会被绕过。 - if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { return nil, blockErr } @@ -6173,7 +6083,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 -func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { if s.settingService == nil || len(tokens) == 0 { return nil } @@ -6185,14 +6095,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { - if rule.Action != BetaPolicyActionBlock { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + if effectiveAction != BetaPolicyActionBlock { continue } if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -7405,6 +7316,8 @@ type RecordUsageInput struct { RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -7434,6 +7347,18 @@ type postUsageBillingParams struct { APIKeyService APIKeyQuotaUpdater } +func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { + return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateRateLimits() bool { + return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { + return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() +} + // postUsageBilling 统一处理使用量记录后的扣费逻辑: // - 订阅/余额扣费 // - API Key 配额更新 @@ -7463,21 +7388,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } // 2. API Key 配额 - if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. API Key 限速用量 - if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) - if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) @@ -7538,9 +7463,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount - if usageLog.MediaType != nil { - cmd.MediaType = *usageLog.MediaType - } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7559,13 +7481,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.BalanceCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { cmd.APIKeyQuotaCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { cmd.APIKeyRateLimitCost = p.Cost.ActualCost } - if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier } @@ -7689,191 +7611,39 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage } } +// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 +type recordUsageOpts struct { + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + ParsedRequest *ParsedRequest + + // EnableClaudePath 启用 Claude 路径特有逻辑: + // - Claude Max 缓存计费策略 + EnableClaudePath bool + + // 长上下文计费(仅 Gemini 路径需要) + LongContextThreshold int + LongContextMultiplier float64 +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { - result := input.Result - apiKey := input.APIKey - user := input.User - account := input.Account - subscription := input.Subscription - - // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens - // 用于粘性会话切换时的特殊计费处理 - if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", - result.Usage.InputTokens, account.ID) - result.Usage.CacheReadInputTokens += result.Usage.InputTokens - result.Usage.InputTokens = 0 - } - - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { - applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) - cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 - } - - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := 1.0 - if s.cfg != nil { - multiplier = s.cfg.Default.RateMultiplier - } - if apiKey.GroupID != nil && apiKey.Group != nil { - groupDefault := apiKey.Group.RateMultiplier - multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) - } - - var cost *CostBreakdown - billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) - - // 根据请求类型选择计费方式 - if result.MediaType == "image" || result.MediaType == "video" { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == "image" { - cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } else { - cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) - } - } else if result.MediaType == "prompt" { - cost = &CostBreakdown{} - } else if result.ImageCount > 0 { - // 图片生成计费 - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) - } else { - // Token 计费 - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - } - var err error - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } - } - - // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := BillingTypeBalance - if isSubscriptionBilling { - billingType = BillingTypeSubscription - } - - // 创建使用日志 - durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } - var mediaType *string - if strings.TrimSpace(result.MediaType) != "" { - mediaType = &result.MediaType - } - accountRateMultiplier := account.BillingRateMultiplier() - requestID := resolveUsageBillingRequestID(ctx, result.RequestID) - usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: result.Model, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - Stream: result.Stream, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - ImageCount: result.ImageCount, - ImageSize: imageSize, - MediaType: mediaType, - CacheTTLOverridden: cacheTTLOverridden, - CreatedAt: time.Now(), - } - - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent - } - - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress - } - - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID - } - if subscription != nil { - usageLog.SubscriptionID = &subscription.ID - } - - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil - } - - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() - - if billingErr != nil { - return billingErr - } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - - return nil + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + EnableClaudePath: true, + }) } // RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) @@ -7892,10 +7662,54 @@ type RecordUsageLongContextInput struct { LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + LongContextThreshold: input.LongContextThreshold, + LongContextMultiplier: input.LongContextMultiplier, + }) +} + +// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。 +type recordUsageCoreInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string + IPAddress string + RequestPayloadHash string + ForceCacheBilling bool + APIKeyService APIKeyQuotaUpdater + ChannelUsageFields +} + +// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 +// opts 中的字段控制两者之间的差异行为: +// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 +// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext +func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result apiKey := input.APIKey user := input.User @@ -7928,38 +7742,23 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } - var cost *CostBreakdown + // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) - - // 根据请求类型选择计费方式 - if result.ImageCount > 0 { - // 图片生成计费 - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) - } else { - // Token 计费(使用长上下文计费方法) - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - } - var err error - cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } + if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { + billingModel = input.ChannelMappedModel } + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } + + // 计算费用 + cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() @@ -7969,12 +7768,182 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 创建使用日志 - durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } accountRateMultiplier := account.BillingRateMultiplier() + usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, + requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + requestID := usageLog.RequestID + _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) + + if billingErr != nil { + return billingErr + } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + + return nil +} + +// calculateRecordUsageCost 根据请求类型和选项计算费用。 +func (s *GatewayService) calculateRecordUsageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + // 图片生成计费 + if result.ImageCount > 0 { + return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) + } + + // Token 计费 + return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) +} + +// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 +// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 +func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey.Group == nil { + return nil + } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + +// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 +func (s *GatewayService) calculateImageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost + } + + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 +func (s *GatewayService) calculateTokenCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + + var cost *CostBreakdown + var err error + + // 优先尝试渠道定价 → CalculateCostUnified + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + gid := apiKey.Group.ID + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + } else if opts.LongContextThreshold > 0 { + // 长上下文双倍计费(如 Gemini 200K 阈值) + cost, err = s.billingService.CalculateCostWithLongContext( + billingModel, tokens, multiplier, + opts.LongContextThreshold, opts.LongContextMultiplier, + ) + } else { + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) + } + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost +} + +// buildRecordUsageLog 构建使用日志并设置计费模式。 +func (s *GatewayService) buildRecordUsageLog( + ctx context.Context, + input *recordUsageCoreInput, + result *ForwardResult, + apiKey *APIKey, + user *User, + account *Account, + subscription *UserSubscription, + requestedModel string, + multiplier float64, + accountRateMultiplier float64, + billingType int8, + cacheTTLOverridden bool, + cost *CostBreakdown, + opts *recordUsageOpts, +) *UsageLog { + durationMs := int(result.Duration.Milliseconds()) requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, @@ -7982,7 +7951,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, - RequestedModel: result.Model, + RequestedModel: requestedModel, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -7993,72 +7962,156 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheReadTokens: result.Usage.CacheReadInputTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, + ImageOutputTokens: result.Usage.ImageOutputTokens, RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, + BillingMode: resolveBillingMode(result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, - ImageSize: imageSize, + ImageSize: optionalTrimmedStringPtr(result.ImageSize), CacheTTLOverridden: cacheTTLOverridden, + ChannelID: optionalInt64Ptr(input.ChannelID), + ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), + UserAgent: optionalTrimmedStringPtr(input.UserAgent), + IPAddress: optionalTrimmedStringPtr(input.IPAddress), + GroupID: apiKey.GroupID, + SubscriptionID: optionalSubscriptionID(subscription), CreatedAt: time.Now(), } - - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent + if cost != nil { + usageLog.InputCost = cost.InputCost + usageLog.OutputCost = cost.OutputCost + usageLog.ImageOutputCost = cost.ImageOutputCost + usageLog.CacheCreationCost = cost.CacheCreationCost + usageLog.CacheReadCost = cost.CacheReadCost + usageLog.TotalCost = cost.TotalCost + usageLog.ActualCost = cost.ActualCost } - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress - } + return usageLog +} - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID +// resolveBillingMode 根据计费结果和请求类型确定计费模式。 +func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { + var mode string + switch { + case cost != nil && cost.BillingMode != "": + mode = cost.BillingMode + case result.ImageCount > 0: + mode = string(BillingModeImage) + default: + mode = string(BillingModeToken) } + return &mode +} + +func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { - usageLog.SubscriptionID = &subscription.ID + return &subscription.ID } - - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil - } - - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() - - if billingErr != nil { - return billingErr - } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - return nil } +// ResolveChannelMapping 委托渠道服务解析模型映射 +func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) +func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} + +// IsModelRestricted 检查模型是否被渠道限制 +func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 +func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) +} + +// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。 +// 供调度阶段预检查(requested / channel_mapped)。 +// upstream 需逐账号检查,此处返回 false。 +func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +// billingModelForRestriction 根据计费基准确定限制检查使用的模型。 +// upstream 返回空(需逐账号检查)。 +func billingModelForRestriction(source, requestedModel, channelMappedModel string) string { + switch source { + case BillingModelSourceRequested: + return requestedModel + case BillingModelSourceUpstream: + return "" + case BillingModelSourceChannelMapped: + return channelMappedModel + default: + return channelMappedModel + } +} + +// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。 +// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。 +func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveAccountUpstreamModel(account, requestedModel) + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。 +func resolveAccountUpstreamModel(account *Account, requestedModel string) string { + if account.Platform == PlatformAntigravity { + return mapAntigravityModel(account, requestedModel) + } + return account.GetMappedModel(requestedModel) +} + +// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。 +func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { @@ -8508,7 +8561,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go index 743d70bb..ac8c6df6 100644 --- a/backend/internal/service/gateway_service_selection_failure_stats_test.go +++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -9,35 +9,35 @@ import ( func TestCollectSelectionFailureStats(t *testing.T) { svc := &GatewayService{} - model := "sora2-landscape-10s" + model := "gpt-5.4" resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) accounts := []Account{ // excluded { ID: 1, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, }, // unschedulable { ID: 2, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: false, }, // platform filtered { ID: 3, - Platform: PlatformOpenAI, + Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true, }, // model unsupported { ID: 4, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Credentials: map[string]any{ @@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) { // model rate limited { ID: 5, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Extra: map[string]any{ @@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) { // eligible { ID: 6, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, }, } excluded := map[int64]struct{}{1: {}} - stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false) if stats.Total != 6 { t.Fatalf("total=%d want=6", stats.Total) @@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) { } } -func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { +func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) { svc := &GatewayService{} acc := &Account{ ID: 7, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: false, } - diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false) if diagnosis.Category != "unschedulable" { t.Fatalf("category=%s want=unschedulable", diagnosis.Category) } - if diagnosis.Detail != "schedulable=false" { - t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + if diagnosis.Detail != "generic_unschedulable" { + t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail) } } -func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { +func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) { svc := &GatewayService{} - model := "sora2-landscape-10s" + model := "gpt-5.4" resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) acc := &Account{ ID: 8, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Extra: map[string]any{ @@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { }, } - diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false) + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false) if diagnosis.Category != "model_rate_limited" { t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) } diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go deleted file mode 100644 index 8ee2a960..00000000 --- a/backend/internal/service/gateway_service_sora_model_support_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package service - -import "testing" - -func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{}, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected sora model to be supported when model_mapping is empty") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "gpt-4o": "gpt-4o", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected sora model to be supported when mapping has no sora selectors") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "sora2": "sora2", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { - t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "sy_8": "sy_8", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "gpt-image": "gpt-image", - }, - }, - } - - if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") - } -} diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go deleted file mode 100644 index 5178e68e..00000000 --- a/backend/internal/service/gateway_service_sora_scheduling_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" -) - -func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { - svc := &GatewayService{} - now := time.Now() - past := now.Add(-1 * time.Minute) - future := now.Add(5 * time.Minute) - - acc := &Account{ - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - AutoPauseOnExpired: true, - ExpiresAt: &past, - OverloadUntil: &future, - RateLimitResetAt: &future, - } - - if !svc.isAccountSchedulableForSelection(acc) { - t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") - } -} - -func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { - svc := &GatewayService{} - future := time.Now().Add(5 * time.Minute) - - acc := &Account{ - Platform: PlatformAnthropic, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &future, - } - - if svc.isAccountSchedulableForSelection(acc) { - t.Fatalf("expected non-sora account to keep generic schedulable checks") - } -} - -func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { - svc := &GatewayService{} - model := "sora2-landscape-10s" - resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) - globalResetAt := time.Now().Add(2 * time.Minute) - - acc := &Account{ - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &globalResetAt, - Extra: map[string]any{ - "model_rate_limits": map[string]any{ - model: map[string]any{ - "rate_limit_reset_at": resetAt, - }, - }, - }, - } - - if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { - t.Fatalf("expected sora account to be blocked by model scope rate limit") - } -} - -func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { - svc := &GatewayService{} - future := time.Now().Add(3 * time.Minute) - - accounts := []Account{ - { - ID: 1, - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &future, - }, - } - - stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) - if stats.Unschedulable != 0 || stats.Eligible != 1 { - t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) - } -} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 5b1abc11..b35ebce5 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage { cand := int(usage.Get("candidatesTokenCount").Int()) cached := int(usage.Get("cachedContentTokenCount").Int()) thoughts := int(usage.Get("thoughtsTokenCount").Int()) + + // 从 candidatesTokensDetails 提取 IMAGE 模态 token 数 + imageTokens := 0 + candidateDetails := usage.Get("candidatesTokensDetails") + if candidateDetails.Exists() { + candidateDetails.ForEach(func(_, detail gjson.Result) bool { + if detail.Get("modality").String() == "IMAGE" { + imageTokens = int(detail.Get("tokenCount").Int()) + return false + } + return true + }) + } + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ InputTokens: prompt - cached, OutputTokens: cand + thoughts, CacheReadInputTokens: cached, + ImageOutputTokens: imageTokens, } } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e17032e0..d59af9e1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,15 +26,6 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 - // Sora 按次计费配置(阶段 1) - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - - // Sora 存储配额 - SoraStorageQuotaBytes int64 - // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -59,6 +50,8 @@ type Group struct { // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) DefaultMappedModel string CreatedAt time.Time @@ -110,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } -// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) -func (g *Group) GetSoraImagePrice(imageSize string) *float64 { - switch imageSize { - case "360": - return g.SoraImagePrice360 - case "540": - return g.SoraImagePrice540 - default: - return g.SoraImagePrice360 - } -} - // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go new file mode 100644 index 00000000..b7ca4cb7 --- /dev/null +++ b/backend/internal/service/model_pricing_resolver.go @@ -0,0 +1,231 @@ +package service + +import ( + "context" + "log/slog" +) + +// PricingSource 定价来源标识 +const ( + PricingSourceChannel = "channel" + PricingSourceLiteLLM = "litellm" + PricingSourceFallback = "fallback" +) + +// ResolvedPricing 统一定价解析结果 +type ResolvedPricing struct { + // Mode 计费模式 + Mode BillingMode + + // Token 模式:基础定价(来自 LiteLLM 或 fallback) + BasePricing *ModelPricing + + // Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段) + Intervals []PricingInterval + + // 按次/图片模式:分层定价 + RequestTiers []PricingInterval + + // 按次/图片模式:默认价格(未命中层级时使用) + DefaultPerRequestPrice float64 + + // 来源标识 + Source string // "channel", "litellm", "fallback" + + // 是否支持缓存细分 + SupportsCacheBreakdown bool +} + +// ModelPricingResolver 统一模型定价解析器。 +// 解析链:Channel → LiteLLM → Fallback。 +type ModelPricingResolver struct { + channelService *ChannelService + billingService *BillingService +} + +// NewModelPricingResolver 创建定价解析器实例 +func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver { + return &ModelPricingResolver{ + channelService: channelService, + billingService: billingService, + } +} + +// PricingInput 定价解析输入 +type PricingInput struct { + Model string + GroupID *int64 // nil 表示不检查渠道 +} + +// Resolve 解析模型定价。 +// 1. 获取基础定价(LiteLLM → Fallback) +// 2. 如果指定了 GroupID,查找渠道定价并覆盖 +func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing { + // 1. 获取基础定价 + basePricing, source := r.resolveBasePricing(input.Model) + + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Source: source, + SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown, + } + + // 2. 如果有 GroupID,尝试渠道覆盖 + if input.GroupID != nil { + r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved) + } + + return resolved +} + +// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价 +func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) { + pricing, err := r.billingService.GetModelPricing(model) + if err != nil { + slog.Debug("failed to get model pricing from LiteLLM, using fallback", + "model", model, "error", err) + return nil, PricingSourceFallback + } + return pricing, PricingSourceLiteLLM +} + +// applyChannelOverrides 应用渠道定价覆盖 +func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) { + chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model) + if chPricing == nil { + return + } + + resolved.Source = PricingSourceChannel + resolved.Mode = chPricing.BillingMode + if resolved.Mode == "" { + resolved.Mode = BillingModeToken + } + + switch resolved.Mode { + case BillingModeToken: + r.applyTokenOverrides(chPricing, resolved) + case BillingModePerRequest, BillingModeImage: + r.applyRequestTierOverrides(chPricing, resolved) + } +} + +// applyTokenOverrides 应用 token 模式的渠道覆盖 +func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + // 过滤掉所有价格字段都为空的无效 interval + validIntervals := filterValidIntervals(chPricing.Intervals) + + // 如果有有效的区间定价,使用区间 + if len(validIntervals) > 0 { + resolved.Intervals = validIntervals + return + } + + // 否则用 flat 字段覆盖 BasePricing + if resolved.BasePricing == nil { + resolved.BasePricing = &ModelPricing{} + } + + if chPricing.InputPrice != nil { + resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice + resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice + } + if chPricing.OutputPrice != nil { + resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice + resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice + } + if chPricing.CacheWritePrice != nil { + resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice + } + if chPricing.CacheReadPrice != nil { + resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice + resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice + } + if chPricing.ImageOutputPrice != nil { + resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice + } +} + +// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖 +func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + resolved.RequestTiers = filterValidIntervals(chPricing.Intervals) + if chPricing.PerRequestPrice != nil { + resolved.DefaultPerRequestPrice = *chPricing.PerRequestPrice + } +} + +// filterValidIntervals 过滤掉所有价格字段都为空的无效 interval。 +// 前端可能创建了只有 min/max 但无价格的空 interval。 +func filterValidIntervals(intervals []PricingInterval) []PricingInterval { + var valid []PricingInterval + for _, iv := range intervals { + if iv.InputPrice != nil || iv.OutputPrice != nil || + iv.CacheWritePrice != nil || iv.CacheReadPrice != nil || + iv.PerRequestPrice != nil { + valid = append(valid, iv) + } + } + return valid +} + +// GetIntervalPricing 根据 context token 数获取区间定价。 +// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。 +func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing { + if len(resolved.Intervals) == 0 { + return resolved.BasePricing + } + + iv := FindMatchingInterval(resolved.Intervals, totalContextTokens) + if iv == nil { + return resolved.BasePricing + } + + return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown) +} + +// intervalToModelPricing 将区间定价转换为 ModelPricing +func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing { + pricing := &ModelPricing{ + SupportsCacheBreakdown: supportsCacheBreakdown, + } + if iv.InputPrice != nil { + pricing.InputPricePerToken = *iv.InputPrice + pricing.InputPricePerTokenPriority = *iv.InputPrice + } + if iv.OutputPrice != nil { + pricing.OutputPricePerToken = *iv.OutputPrice + pricing.OutputPricePerTokenPriority = *iv.OutputPrice + } + if iv.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *iv.CacheWritePrice + pricing.CacheCreation5mPrice = *iv.CacheWritePrice + pricing.CacheCreation1hPrice = *iv.CacheWritePrice + } + if iv.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *iv.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice + } + return pricing +} + +// GetRequestTierPrice 根据层级标签获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 { + for _, tier := range resolved.RequestTiers { + if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil { + return *tier.PerRequestPrice + } + } + return 0 +} + +// GetRequestTierPriceByContext 根据 context token 数获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 { + iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens) + if iv != nil && iv.PerRequestPrice != nil { + return *iv.PerRequestPrice + } + return 0 +} diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go new file mode 100644 index 00000000..905c4df6 --- /dev/null +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -0,0 +1,663 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTestBillingServiceForResolver() *BillingService { + bs := &BillingService{ + fallbackPrices: make(map[string]*ModelPricing), + } + bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{ + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + CacheCreationPricePerToken: 3.75e-6, + CacheReadPricePerToken: 0.3e-6, + SupportsCacheBreakdown: false, + } + return bs +} + +func TestResolve_NoGroupID(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // BillingService.GetModelPricing uses fallback internally, but resolveBasePricing + // reports "litellm" when GetModelPricing succeeds (regardless of internal source) + require.Equal(t, "litellm", resolved.Source) +} + +func TestResolve_UnknownModel(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Nil(t, resolved.BasePricing) + // Unknown model: GetModelPricing returns error, source is "fallback" + require.Equal(t, "fallback", resolved.Source) +} + +func TestGetIntervalPricing_NoIntervals(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 5e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: nil, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.Equal(t, basePricing, result) +} + +func TestGetIntervalPricing_MatchesInterval(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: &ModelPricing{InputPricePerToken: 5e-6}, + SupportsCacheBreakdown: true, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, result) + require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12) + require.True(t, result.SupportsCacheBreakdown) + + result2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, result2) + require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 99e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 5000) + require.Equal(t, basePricing, result) +} + +func TestGetRequestTierPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + }, + } + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12) +} + +func TestGetRequestTierPriceByContext(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + } + + require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: nil}, + }, + } + + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12) +} + +// =========================================================================== +// Channel override tests — exercises applyChannelOverrides via Resolve +// =========================================================================== + +// helper: creates a resolver wired to a ChannelService that returns the given +// channel (active, groupID=100, platform=anthropic) with the specified pricing. +func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver { + t.Helper() + const groupID = 100 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + GroupIDs: []int64{groupID}, + ModelPricing: pricing, + }}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{groupID: "anthropic"}, nil + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + return NewModelPricingResolver(cs, bs) +} + +// groupIDPtr returns a pointer to groupID 100 (the test constant). +func groupIDPtr() *int64 { v := int64(100); return &v } + +// --------------------------------------------------------------------------- +// 1. Token mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(50e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) { + // Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback). + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(20e-6), + // OutputPrice intentionally nil + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + // InputPrice overridden by channel + require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // OutputPrice kept from base (fallback: 15e-6) + require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.Len(t, resolved.Intervals, 2) + + // GetIntervalPricing should use channel intervals + iv := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, iv) + require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12) + require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12) + + iv2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, iv2) + require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12) + require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) { + // Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"unknown-model-xyz"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(7e-6), + OutputPrice: testPtrFloat64(21e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + // BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 2. Per-request mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_PerRequest(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 2) + + // Verify tier lookups + require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) { + // PerRequestPrice nil → DefaultPerRequestPrice stays 0. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + // PerRequestPrice intentionally nil + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 1) +} + +// --------------------------------------------------------------------------- +// 3. Image mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_Image(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.08), + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeImage, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 3) +} + +func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found +} + +// --------------------------------------------------------------------------- +// 4. Source tracking & default mode +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(1e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) +} + +func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) { + // Channel pricing with empty BillingMode → defaults to BillingModeToken. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: "", // intentionally empty + InputPrice: testPtrFloat64(5e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 5. GetIntervalPricing integration after channel override +// --------------------------------------------------------------------------- + +func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) { + // Channel provides intervals that override the base pricing path. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)}, + {MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 50000 matches first interval + pricing := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, pricing) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12) + + // Token count 150000 matches second interval + pricing2 := r.GetIntervalPricing(resolved, 150000) + require.NotNil(t, pricing2) + require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) { + // Channel intervals don't match token count → falls back to BasePricing. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + // Only covers tokens > 50000 + {MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 1000 doesn't match any interval (1000 <= 50000 minTokens) + pricing := r.GetIntervalPricing(resolved, 1000) + // Should fall back to BasePricing (from the billing service fallback) + require.NotNil(t, pricing) + require.Equal(t, resolved.BasePricing, pricing) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price +} + +// =========================================================================== +// 6. Error path tests +// =========================================================================== + +func TestResolve_WithChannelOverride_CacheError(t *testing.T) { + // When ListAll returns an error, the ChannelService cache build fails. + // Resolve should gracefully fall back to base pricing without panicking. + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, errors.New("database unavailable") + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(cs, bs) + + gid := int64(100) + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: &gid, + }) + + require.NotNil(t, resolved) + // Should NOT panic, should NOT have source "channel" + require.NotEqual(t, "channel", resolved.Source) + // Base pricing should still be present (from BillingService fallback) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// =========================================================================== +// 7. GetRequestTierPriceByContext boundary tests +// =========================================================================== + +func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: nil, // empty + } + + price := r.GetRequestTierPriceByContext(resolved, 50000) + require.InDelta(t, 0.0, price, 1e-12) + + // Also test with explicit empty slice + resolved2 := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{}, + } + + price2 := r.GetRequestTierPriceByContext(resolved2, 50000) + require.InDelta(t, 0.0, price2, 1e-12) +} + +func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + } + + // totalContextTokens = 128000 exactly: + // FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens + // For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval + price := r.GetRequestTierPriceByContext(resolved, 128000) + require.InDelta(t, 0.05, price, 1e-12) + + // totalContextTokens = 128001 should match second interval + // For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match + // For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches + price2 := r.GetRequestTierPriceByContext(resolved, 128001) + require.InDelta(t, 0.10, price2, 1e-12) +} + +// =========================================================================== +// 8. filterValidIntervals +// =========================================================================== + +func TestFilterValidIntervals(t *testing.T) { + tests := []struct { + name string + intervals []PricingInterval + wantLen int + }{ + { + name: "empty list", + intervals: nil, + wantLen: 0, + }, + { + name: "all-nil interval filtered out", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000)}, + }, + wantLen: 0, + }, + { + name: "interval with only InputPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only OutputPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), OutputPrice: testPtrFloat64(2e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only CacheWritePrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, CacheWritePrice: testPtrFloat64(3e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only CacheReadPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, CacheReadPrice: testPtrFloat64(0.5e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only PerRequestPrice kept", + intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + }, + wantLen: 1, + }, + { + name: "mixed valid and invalid", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil}, // all-nil → filtered out + {MinTokens: 256000, OutputPrice: testPtrFloat64(5e-6)}, + }, + wantLen: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterValidIntervals(tt.intervals) + require.Len(t, result, tt.wantLen) + }) + } +} diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 5dbba638..571e9ecd 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "strconv" + "strings" + "sync" "time" ) @@ -17,7 +19,7 @@ type OAuthRefreshExecutor interface { CacheKey(account *Account) string } -const refreshLockTTL = 30 * time.Second +const defaultRefreshLockTTL = 60 * time.Second // OAuthRefreshResult 统一刷新结果 type OAuthRefreshResult struct { @@ -28,20 +30,39 @@ type OAuthRefreshResult struct { } // OAuthRefreshAPI 统一的 OAuth Token 刷新入口 -// 封装分布式锁、DB 重读、已刷新检查等通用逻辑 +// 封装分布式锁、进程内互斥锁、DB 重读、已刷新检查、竞争恢复等通用逻辑 type OAuthRefreshAPI struct { accountRepo AccountRepository - tokenCache GeminiTokenCache // 可选,nil = 无锁 + tokenCache GeminiTokenCache // 可选,nil = 无分布式锁 + lockTTL time.Duration + localLocks sync.Map // key: cacheKey string -> value: *sync.Mutex } // NewOAuthRefreshAPI 创建统一刷新 API -func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { +// 可选传入 lockTTL 覆盖默认的 60s 分布式锁 TTL +func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache, lockTTL ...time.Duration) *OAuthRefreshAPI { + ttl := defaultRefreshLockTTL + if len(lockTTL) > 0 && lockTTL[0] > 0 { + ttl = lockTTL[0] + } return &OAuthRefreshAPI{ accountRepo: accountRepo, tokenCache: tokenCache, + lockTTL: ttl, } } +// getLocalLock 返回指定 cacheKey 的进程内互斥锁 +func (api *OAuthRefreshAPI) getLocalLock(cacheKey string) *sync.Mutex { + actual, _ := api.localLocks.LoadOrStore(cacheKey, &sync.Mutex{}) + mu, ok := actual.(*sync.Mutex) + if !ok { + mu = &sync.Mutex{} + api.localLocks.Store(cacheKey, mu) + } + return mu +} + // RefreshIfNeeded 在分布式锁保护下按需刷新 OAuth token // // 流程: @@ -59,12 +80,17 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( ) (*OAuthRefreshResult, error) { cacheKey := executor.CacheKey(account) + // 0. 获取进程内互斥锁(防止同一进程内的并发刷新竞争) + localMu := api.getLocalLock(cacheKey) + localMu.Lock() + defer localMu.Unlock() + // 1. 获取分布式锁 lockAcquired := false if api.tokenCache != nil { - acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL) + acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, api.lockTTL) if lockErr != nil { - // Redis 错误,降级为无锁刷新 + // Redis 错误,降级为无锁刷新(进程内互斥锁仍生效) slog.Warn("oauth_refresh_lock_failed_degraded", "account_id", account.ID, "cache_key", cacheKey, @@ -102,6 +128,19 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 4. 执行平台特定刷新逻辑 newCredentials, refreshErr := executor.Refresh(ctx, freshAccount) if refreshErr != nil { + // 竞争恢复:invalid_grant 可能是另一个 worker 已消费了旧 refresh_token + // 重新读取 DB,如果 refresh_token 已更新则说明是竞争,返回成功 + if isInvalidGrantError(refreshErr) { + if recoveredAccount, recovered := api.tryRecoverFromRefreshRace(ctx, freshAccount); recovered { + slog.Info("oauth_refresh_race_recovered", + "account_id", freshAccount.ID, + "platform", freshAccount.Platform, + ) + return &OAuthRefreshResult{ + Account: recoveredAccount, + }, nil + } + } return nil, refreshErr } @@ -126,6 +165,33 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( }, nil } +// isInvalidGrantError 检查错误是否为 invalid_grant +func isInvalidGrantError(err error) bool { + return err != nil && strings.Contains(strings.ToLower(err.Error()), "invalid_grant") +} + +// tryRecoverFromRefreshRace 在 invalid_grant 错误后尝试竞争恢复 +// 重新读取 DB,如果 refresh_token 已改变(说明另一个 worker 成功刷新),则返回更新后的 account +func (api *OAuthRefreshAPI) tryRecoverFromRefreshRace(ctx context.Context, usedAccount *Account) (*Account, bool) { + if api.accountRepo == nil { + return nil, false + } + reReadAccount, err := api.accountRepo.GetByID(ctx, usedAccount.ID) + if err != nil || reReadAccount == nil { + return nil, false + } + usedRT := usedAccount.GetCredential("refresh_token") + currentRT := reReadAccount.GetCredential("refresh_token") + if usedRT == "" || currentRT == "" { + return nil, false + } + // refresh_token 不同 → 另一个 worker 已成功刷新 + if usedRT != currentRT { + return reReadAccount, true + } + return nil, false +} + // MergeCredentials 将旧 credentials 中不存在于新 map 的字段保留到新 map 中 func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any { if newCreds == nil { diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index c3b38ddf..4a60723b 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -5,6 +5,7 @@ package service import ( "context" "errors" + "sync" "testing" "time" @@ -385,6 +386,224 @@ func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) { require.False(t, hasScope, "scope should not be set when empty") } +// refreshAPIAccountRepoWithRace supports returning a different account on subsequent GetByID calls +// to simulate race conditions where another worker has refreshed the token. +type refreshAPIAccountRepoWithRace struct { + refreshAPIAccountRepo + raceAccount *Account // returned on 2nd+ GetByID call + getByIDCalls int +} + +func (r *refreshAPIAccountRepoWithRace) GetByID(_ context.Context, _ int64) (*Account, error) { + r.getByIDCalls++ + if r.getByIDCalls > 1 && r.raceAccount != nil { + return r.raceAccount, nil + } + if r.getByIDErr != nil { + return nil, r.getByIDErr + } + return r.account, nil +} + +// ========== Race recovery tests ========== + +func TestRefreshIfNeeded_InvalidGrantRaceRecovered(t *testing.T) { + // Account with old refresh token + account := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt", "access_token": "old-at"}, + } + // After race, DB has new refresh token from another worker + racedAccount := &Account{ + ID: 10, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: racedAccount, + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token not found or invalid"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err, "race-recovered invalid_grant should not return error") + require.False(t, result.Refreshed) + require.False(t, result.LockHeld) + require.NotNil(t, result.Account) + require.Equal(t, "new-rt", result.Account.GetCredential("refresh_token")) + require.Equal(t, 0, repo.updateCalls) // no DB update needed, another worker did it +} + +func TestRefreshIfNeeded_InvalidGrantGenuine(t *testing.T) { + // Account with revoked refresh token - DB still has the same token + account := &Account{ + ID: 11, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "revoked-rt", "access_token": "old-at"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: account, // same refresh_token on re-read + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant: refresh token revoked"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "genuine invalid_grant should propagate error") + require.Nil(t, result) + require.Contains(t, err.Error(), "invalid_grant") +} + +func TestRefreshIfNeeded_InvalidGrantDBRereadFailsOnRecovery(t *testing.T) { + account := &Account{ + ID: 12, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + } + repo := &refreshAPIAccountRepoWithRace{ + refreshAPIAccountRepo: refreshAPIAccountRepo{account: account}, + raceAccount: nil, // GetByID returns nil on recovery attempt + } + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + err: errors.New("invalid_grant"), + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.Error(t, err, "should propagate error when recovery DB re-read fails") + require.Nil(t, result) +} + +func TestRefreshIfNeeded_LocalMutexSerializesConcurrent(t *testing.T) { + // Test that two goroutines for the same account are serialized by the local mutex. + // The first goroutine refreshes successfully; the second sees NeedsRefresh=false. + refreshed := &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "new-rt", "access_token": "new-at"}, + } + callCount := 0 + repo := &refreshAPIAccountRepo{account: &Account{ + ID: 20, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{"refresh_token": "old-rt"}, + }} + + // After first refresh, NeedsRefresh should return false + // We simulate this by using an executor that decrements needsRefresh after first call + var mu sync.Mutex + dynamicExecutor := &dynamicRefreshExecutor{ + canRefresh: true, + cacheKey: "test:mutex:anthropic", + refreshFunc: func(_ context.Context, _ *Account) (map[string]any, error) { + mu.Lock() + callCount++ + mu.Unlock() + time.Sleep(50 * time.Millisecond) // slow refresh + return map[string]any{"access_token": "new-at"}, nil + }, + needsRefreshFunc: func() bool { + mu.Lock() + defer mu.Unlock() + return callCount == 0 // only first call needs refresh + }, + } + + _ = refreshed + + api := NewOAuthRefreshAPI(repo, nil) // no distributed lock, only local mutex + + var wg sync.WaitGroup + results := make([]*OAuthRefreshResult, 2) + errs := make([]error, 2) + + for i := 0; i < 2; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx], errs[idx] = api.RefreshIfNeeded(context.Background(), repo.account, dynamicExecutor, 3*time.Minute) + }(i) + } + wg.Wait() + + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + + // Only one goroutine should have actually called Refresh + mu.Lock() + require.Equal(t, 1, callCount, "only one refresh call should have been made") + mu.Unlock() +} + +// dynamicRefreshExecutor is a test helper with function-based NeedsRefresh and Refresh. +type dynamicRefreshExecutor struct { + canRefresh bool + cacheKey string + needsRefreshFunc func() bool + refreshFunc func(context.Context, *Account) (map[string]any, error) +} + +func (e *dynamicRefreshExecutor) CanRefresh(_ *Account) bool { return e.canRefresh } + +func (e *dynamicRefreshExecutor) NeedsRefresh(_ *Account, _ time.Duration) bool { + return e.needsRefreshFunc() +} + +func (e *dynamicRefreshExecutor) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + return e.refreshFunc(ctx, account) +} + +func (e *dynamicRefreshExecutor) CacheKey(_ *Account) string { + return e.cacheKey +} + +// ========== NewOAuthRefreshAPI TTL tests ========== + +func TestNewOAuthRefreshAPI_DefaultTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_CustomTTL(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 90*time.Second) + require.Equal(t, 90*time.Second, api.lockTTL) +} + +func TestNewOAuthRefreshAPI_ZeroTTLUsesDefault(t *testing.T) { + api := NewOAuthRefreshAPI(nil, nil, 0) + require.Equal(t, defaultRefreshLockTTL, api.lockTTL) +} + +// ========== isInvalidGrantError tests ========== + +func TestIsInvalidGrantError(t *testing.T) { + require.True(t, isInvalidGrantError(errors.New("invalid_grant: token revoked"))) + require.True(t, isInvalidGrantError(errors.New("INVALID_GRANT"))) + require.False(t, isInvalidGrantError(errors.New("invalid_client"))) + require.False(t, isInvalidGrantError(nil)) +} + // ========== BackgroundRefreshPolicy tests ========== func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) { diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 37e7ed2c..6c09e354 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -4,6 +4,7 @@ import ( "container/heap" "context" "errors" + "fmt" "hash/fnv" "math" "sort" @@ -575,6 +576,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( return nil, 0, 0, 0, errors.New("no available OpenAI accounts") } + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if req.GroupID != nil && s.service.schedulerSnapshot != nil { + schedGroup, _ = s.service.schedulerSnapshot.GetGroupByID(ctx, *req.GroupID) + } + filtered := make([]*Account, 0, len(accounts)) loadReq := make([]AccountWithConcurrency, 0, len(accounts)) for i := range accounts { @@ -587,6 +594,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if !account.IsSchedulable() || !account.IsOpenAI() { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() { + _ = s.service.accountRepo.SetError(ctx, account.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { continue } diff --git a/backend/internal/service/openai_channel_restriction_test.go b/backend/internal/service/openai_channel_restriction_test.go new file mode 100644 index 00000000..c9dbceab --- /dev/null +++ b/backend/internal/service/openai_channel_restriction_test.go @@ -0,0 +1,140 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + PlatformOpenAI: {"gpt-4.1": "o3-mini"}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true}, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + _, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.ErrorIs(t, err, ErrNoAvailableAccounts) + require.Contains(t, err.Error(), "channel pricing restriction") +} + +func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) +} + +func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:sticky-session": 1}, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + cache: cache, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) + require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"]) + require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"]) +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index d0534d8c..4ec038e0 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if v, ok := reqBody["model"].(string); ok { model = v } - normalizedModel := normalizeCodexModel(model) + normalizedModel := strings.TrimSpace(model) if normalizedModel != "" { if model != normalizedModel { reqBody["model"] = normalizedModel @@ -275,6 +275,13 @@ func normalizeCodexModel(model string) string { return "gpt-5.1" } +func normalizeOpenAIModelForUpstream(account *Account, model string) string { + if account == nil || account.Type == AccountTypeOAuth { + return normalizeCodexModel(model) + } + return strings.TrimSpace(model) +} + func SupportsVerbosity(model string) bool { if !strings.HasPrefix(model, "gpt-") { return true diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index eab88c09..889ac615 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt 5.3 codex spark": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex", @@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) +} + +func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) { + reqBody := map[string]any{ + "model": " gpt-5.3-codex-spark ", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { // Codex CLI 场景:已有 instructions 时不修改 diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go index eb9148de..6ca3e85c 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key_test.go +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) } @@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") require.NotEqual(t, k1, k2, "different first user messages should yield different keys") } + +func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) { + req := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.3-codex-spark", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + }, + } + + k1 := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark") + k2 := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ") + require.NotEmpty(t, k1) + require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index a442da33..9b3f69bc 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) promptCacheKey = strings.TrimSpace(promptCacheKey) compatPromptCacheInjected := false - if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) { - promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel) + if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel) compatPromptCacheInjected = promptCacheKey != "" } @@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( if err != nil { return nil, fmt.Errorf("convert chat completions to responses: %w", err) } - responsesReq.Model = mappedModel + responsesReq.Model = upstreamModel logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", clientStream), } if compatPromptCacheInjected { @@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) } else { - result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -238,6 +244,7 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( var finalResponse *apicompat.ResponsesResponse var usage OpenAIUsage + acc := apicompat.NewBufferedResponseAccumulator() for scanner.Scan() { line := scanner.Text() @@ -255,7 +262,11 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( continue } - if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + // Accumulate delta content for fallback when terminal output is empty. + acc.ProcessEvent(&event) + + if (event.Type == "response.completed" || event.Type == "response.done" || + event.Type == "response.incomplete" || event.Type == "response.failed") && event.Response != nil { finalResponse = event.Response if event.Response.Usage != nil { @@ -284,6 +295,10 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( return nil, fmt.Errorf("upstream stream ended without terminal event") } + // When the terminal event has an empty output array, reconstruct from + // accumulated delta events so the client receives the full content. + acc.SupplementResponseOutput(finalResponse) + chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel) if s.responseHeaderFilter != nil { @@ -295,8 +310,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -308,7 +323,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, includeUsage bool, startTime time.Time, ) (*OpenAIForwardResult, error) { @@ -343,8 +359,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 3df91b56..6f53928b 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } originalModel := anthropicReq.Model applyOpenAICompatModelNormalization(&anthropicReq) + normalizedModel := anthropicReq.Model clientStream := anthropicReq.Stream // client's original stream preference // 2. Convert Anthropic → Responses @@ -60,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel) - responsesReq.Model = mappedModel + billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("normalized_model", normalizedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), ) @@ -82,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -182,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } else { // Client wants JSON: buffer the streaming response and assemble a JSON reply. - result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -230,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -303,8 +311,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -319,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -352,8 +361,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 7a636afa..38b97b11 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, &DeferredService{}, nil, + nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, @@ -931,6 +933,89 @@ func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel( require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount) } +func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingModelWhenUnmapped(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + // When channel did NOT map the model (ChannelMappedModel == OriginalModel), + // billing should use result.BillingModel (the actual model used after group + // DefaultMappedModel resolution), not the unmapped original model. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_channel_unmapped_billing", + Model: "glm", + BillingModel: "gpt-5.1", + UpstreamModel: "gpt-5.1", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + ChannelUsageFields: ChannelUsageFields{ + ChannelID: 1, + OriginalModel: "glm", + ChannelMappedModel: "glm", // channel did NOT map + BillingModelSource: BillingModelSourceChannelMapped, + }, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenMapped(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + // When channel DID map the model (ChannelMappedModel != OriginalModel), + // billing should use the channel-mapped model, honoring admin intent. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_channel_mapped_billing", + Model: "glm", + BillingModel: "gpt-5.1-codex", + UpstreamModel: "gpt-5.1-codex", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + ChannelUsageFields: ChannelUsageFields{ + ChannelID: 1, + OriginalModel: "glm", + ChannelMappedModel: "gpt-5.1", // channel mapped glm → gpt-5.1 + BillingModelSource: BillingModelSourceChannelMapped, + }, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 0a959615..65e70408 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "math/rand" "net/http" "sort" @@ -20,6 +21,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" @@ -204,6 +206,7 @@ type OpenAIUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // OpenAIForwardResult represents the result of forwarding @@ -322,6 +325,8 @@ type OpenAIGatewayService struct { openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver + resolver *ModelPricingResolver + channelService *ChannelService openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -357,6 +362,8 @@ func NewOpenAIGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, + resolver *ModelPricingResolver, + channelService *ChannelService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -384,6 +391,8 @@ func NewOpenAIGatewayService( openAITokenProvider: openAITokenProvider, toolCorrector: NewCodexToolCorrector(), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + resolver: resolver, + channelService: channelService, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -391,6 +400,74 @@ func NewOpenAIGatewayService( return svc } +// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService) +func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService) +func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段,restricted 始终返回 false。 +func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) +} + +func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + +// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 +func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} + func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { if s != nil && s.codexSnapshotThrottle != nil { return s.codexSnapshotThrottle @@ -1125,6 +1202,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C } func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 1. 尝试粘性会话命中 // Try sticky session hit if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { @@ -1140,7 +1224,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1206,6 +1290,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } + if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) && + s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1218,8 +1307,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. -func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) for i := range accounts { acc := &accounts[i] @@ -1238,6 +1328,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1289,7 +1382,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + cfg := s.schedulingConfig() + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { @@ -1365,6 +1466,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1410,6 +1513,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } candidates = append(candidates, acc) } @@ -1434,6 +1540,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1488,6 +1597,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1510,6 +1622,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } return &AccountSelectionResult{ Account: fresh, WaitPlan: &AccountWaitPlan{ @@ -1814,29 +1929,31 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // 对所有请求执行模型映射(包含 Codex CLI)。 - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) - reqBody["model"] = mappedModel + billingModel := account.GetMappedModel(reqModel) + if billingModel != reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) + reqBody["model"] = billingModel bodyModified = true - markPatchSet("model", mappedModel) + markPatchSet("model", billingModel) } + upstreamModel := billingModel - // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 + // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 + // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, + // 以兼容自定义 base_url 的 OpenAI-compatible 上游。 if model, ok := reqBody["model"].(string); ok { - normalizedModel := normalizeCodexModel(model) - if normalizedModel != "" && normalizedModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, normalizedModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = normalizedModel - mappedModel = normalizedModel + upstreamModel = normalizeOpenAIModelForUpstream(account, model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel bodyModified = true - markPatchSet("model", normalizedModel) + markPatchSet("model", upstreamModel) } // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 // 确保高版本模型向低版本模型映射不报错 - if !SupportsVerbosity(normalizedModel) { + if !SupportsVerbosity(upstreamModel) { if text, ok := reqBody["text"].(map[string]any); ok { delete(text, "verbosity") } @@ -1860,7 +1977,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } if codexResult.NormalizedModel != "" { - mappedModel = codexResult.NormalizedModel + upstreamModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey @@ -1977,7 +2094,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", account.ID, account.Type, - mappedModel, + upstreamModel, reqStream, hasPreviousResponseID, ) @@ -2066,7 +2183,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco isCodexCLI, reqStream, originalModel, - mappedModel, + upstreamModel, startTime, attempt, wsLastFailureReason, @@ -2167,7 +2284,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) - wsResult.UpstreamModel = mappedModel + wsResult.UpstreamModel = upstreamModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2272,14 +2389,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var usage *OpenAIUsage var firstTokenMs *int if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } @@ -2303,7 +2420,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, - UpstreamModel: mappedModel, + UpstreamModel: upstreamModel, ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, @@ -2430,7 +2547,11 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { - // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + // 透传模式默认保持原样代理;但 429/529 属于网关必须兜底的 + // 上游容量类错误,应先触发多账号 failover 以维持基础 SLA。 + if shouldFailoverOpenAIPassthroughResponse(resp.StatusCode) { + return nil, s.handleFailoverErrorResponsePassthrough(ctx, resp, c, account, body) + } return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) } @@ -2613,6 +2734,58 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( return req, nil } +func shouldFailoverOpenAIPassthroughResponse(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, 529: + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + if s.rateLimitService != nil { + _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + return &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + ResponseHeaders: resp.Header.Clone(), + } +} + func (s *OpenAIGatewayService) handleErrorResponsePassthrough( ctx context.Context, resp *http.Response, @@ -3731,6 +3904,16 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { *usage = parsedUsage } + // When the terminal event has an empty output array, reconstruct + // output from accumulated delta events so the client gets full content. + // gjson Array() returns empty slice for null, missing, or empty arrays. + if len(gjson.GetBytes(finalResponse, "output").Array()) == 0 { + if outputJSON, reconstructed := reconstructResponseOutputFromSSE(bodyText); reconstructed { + if patched, err := sjson.SetRawBytes(finalResponse, "output", outputJSON); err == nil { + finalResponse = patched + } + } + } body = finalResponse if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) @@ -3832,6 +4015,34 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { return nil, false } +// reconstructResponseOutputFromSSE scans raw SSE body text for delta events and +// returns a JSON-encoded output array reconstructed from accumulated deltas. +// Returns (nil, false) if no content was found in deltas. +func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { + acc := apicompat.NewBufferedResponseAccumulator() + lines := strings.Split(bodyText, "\n") + for _, line := range lines { + data, ok := extractOpenAISSEDataLine(line) + if !ok || data == "" || data == "[DONE]" { + continue + } + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + acc.ProcessEvent(&event) + } + if !acc.HasContent() { + return nil, false + } + output := acc.BuildOutput() + outputJSON, err := json.Marshal(output) + if err != nil { + return nil, false + } + return outputJSON, true +} + func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") @@ -4110,6 +4321,7 @@ type OpenAIRecordUsageInput struct { IPAddress string // 请求的客户端 IP 地址 RequestPayloadHash string APIKeyService APIKeyQuotaUpdater + ChannelUsageFields } // RecordUsage records usage and deducts balance @@ -4140,10 +4352,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, } // Get rate multiplier - multiplier := s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { resolver := s.userGroupRateResolver if resolver == nil { @@ -4152,12 +4368,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } + var cost *CostBreakdown + var err error billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if result.BillingModel != "" { + billingModel = strings.TrimSpace(result.BillingModel) + } + if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" && input.ChannelMappedModel != input.OriginalModel { + billingModel = input.ChannelMappedModel + } + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } else { + cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + } if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -4173,36 +4414,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } + usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: result.Model, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ServiceTier: result.ServiceTier, - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: actualInputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - Stream: result.Stream, - OpenAIWSMode: result.OpenAIWSMode, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - CreatedAt: time.Now(), + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: result.Model, + RequestedModel: requestedModel, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), + ServiceTier: result.ServiceTier, + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), + InputTokens: actualInputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + if cost != nil { + usageLog.InputCost = cost.InputCost + usageLog.OutputCost = cost.OutputCost + usageLog.ImageOutputCost = cost.ImageOutputCost + usageLog.CacheCreationCost = cost.CacheCreationCost + usageLog.CacheReadCost = cost.CacheReadCost + usageLog.TotalCost = cost.TotalCost + usageLog.ActualCost = cost.ActualCost + } + usageLog.RateMultiplier = multiplier + usageLog.AccountRateMultiplier = &accountRateMultiplier + usageLog.BillingType = billingType + usageLog.Stream = result.Stream + usageLog.OpenAIWSMode = result.OpenAIWSMode + usageLog.DurationMs = &durationMs + usageLog.FirstTokenMs = result.FirstTokenMs + usageLog.CreatedAt = time.Now() + // 设置渠道信息 + usageLog.ChannelID = optionalInt64Ptr(input.ChannelID) + usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain) + // 设置计费模式 + if cost != nil && cost.BillingMode != "" { + billingMode := cost.BillingMode + usageLog.BillingMode = &billingMode + } else { + billingMode := string(BillingModeToken) + usageLog.BillingMode = &billingMode } // 添加 UserAgent if input.UserAgent != "" { diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index edbb968b..cda7e369 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -74,13 +74,64 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * Credentials: map[string]any{}, } - withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") - if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1") + withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) + if withoutDefault != "gpt-5.1" { + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") } - withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") - if got := normalizeCodexModel(withDefault); got != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4") + withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + if withDefault != "gpt-5.4" { + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4") + } +} + +func TestNormalizeCodexModel(t *testing.T) { + cases := map[string]string{ + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + } + + for input, expected := range cases { + if got := normalizeCodexModel(input); got != expected { + t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected) + } + } +} + +func TestNormalizeOpenAIModelForUpstream(t *testing.T) { + tests := []struct { + name string + account *Account + model string + want string + }{ + { + name: "oauth keeps codex normalization behavior", + account: &Account{Type: AccountTypeOAuth}, + model: "gemini-3-flash-preview", + want: "gpt-5.1", + }, + { + name: "apikey preserves custom compatible model", + account: &Account{Type: AccountTypeAPIKey}, + model: "gemini-3-flash-preview", + want: "gemini-3-flash-preview", + }, + { + name: "apikey preserves official non codex model", + account: &Account{Type: AccountTypeAPIKey}, + model: "gpt-4.1", + want: "gpt-4.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeOpenAIModelForUpstream(tt.account, tt.model); got != tt.want { + t.Fatalf("normalizeOpenAIModelForUpstream(...) = %q, want %q", got, tt.want) + } + }) } } diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 97fa218d..69c9de42 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -48,6 +48,22 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc return u.Do(req, proxyURL, accountID, accountConcurrency) } +type openAIPassthroughFailoverRepo struct { + stubOpenAIAccountRepo + rateLimitCalls []time.Time + overloadCalls []time.Time +} + +func (r *openAIPassthroughFailoverRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error { + r.rateLimitCalls = append(r.rateLimitCalls, resetAt) + return nil +} + +func (r *openAIPassthroughFailoverRepo) SetOverloaded(_ context.Context, _ int64, until time.Time) error { + r.overloadCalls = append(r.overloadCalls, until) + return nil +} + var structuredLogCaptureMu sync.Mutex type inMemoryLogSink struct { @@ -527,6 +543,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF _, err := svc.Forward(context.Background(), c, account, originalBody) require.Error(t, err) + require.True(t, c.Writer.Written(), "非 429/529 的 passthrough 错误应继续原样写回客户端") + require.Equal(t, http.StatusBadRequest, rec.Code) // should append an upstream error event with passthrough=true v, ok := c.Get(OpsUpstreamErrorsKey) @@ -535,55 +553,145 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, ok) require.NotEmpty(t, arr) require.True(t, arr[len(arr)-1].Passthrough) + require.Equal(t, "http_error", arr[len(arr)-1].Kind) } -func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) { +func TestOpenAIGatewayService_OpenAIPassthrough_429And529TriggerFailover(t *testing.T) { gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) - resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() - resp := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{ - "Content-Type": []string{"application/json"}, - "x-request-id": []string{"rid-rate-limit"}, + + newAccount := func(accountType string) *Account { + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: accountType, + Concurrency: 1, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + switch accountType { + case AccountTypeOAuth: + account.Credentials = map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"} + case AccountTypeAPIKey: + account.Credentials = map[string]any{"api_key": "sk-test"} + } + return account + } + + testCases := []struct { + name string + accountType string + statusCode int + body string + assertRepo func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) + }{ + { + name: "oauth_429_rate_limit", + accountType: AccountTypeOAuth, + statusCode: http.StatusTooManyRequests, + body: func() string { + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt) + }(), + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) { + require.Len(t, repo.rateLimitCalls, 1) + require.Empty(t, repo.overloadCalls) + require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour) + }, + }, + { + name: "oauth_529_overload", + accountType: AccountTypeOAuth, + statusCode: 529, + body: `{"error":{"message":"server overloaded","type":"server_error"}}`, + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) { + require.Empty(t, repo.rateLimitCalls) + require.Len(t, repo.overloadCalls, 1) + require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second) + }, + }, + { + name: "apikey_429_rate_limit", + accountType: AccountTypeAPIKey, + statusCode: http.StatusTooManyRequests, + body: func() string { + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt) + }(), + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) { + require.Len(t, repo.rateLimitCalls, 1) + require.Empty(t, repo.overloadCalls) + require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour) + }, + }, + { + name: "apikey_529_overload", + accountType: AccountTypeAPIKey, + statusCode: 529, + body: `{"error":{"message":"server overloaded","type":"server_error"}}`, + assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) { + require.Empty(t, repo.rateLimitCalls) + require.Len(t, repo.overloadCalls, 1) + require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second) + }, }, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))), - } - upstream := &httpUpstreamRecorder{resp: resp} - repo := &openAIWSRateLimitSignalRepo{} - rateSvc := &RateLimitService{accountRepo: repo} - - svc := &OpenAIGatewayService{ - cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, - httpUpstream: upstream, - rateLimitService: rateSvc, } - account := &Account{ - ID: 123, - Name: "acc", - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_passthrough": true}, - Status: StatusActive, - Schedulable: true, - RateMultiplier: f64p(1), - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - _, err := svc.Forward(context.Background(), c, account, originalBody) - require.Error(t, err) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - require.Contains(t, rec.Body.String(), "usage_limit_reached") - require.Len(t, repo.rateLimitCalls, 1) - require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) + resp := &http.Response{ + StatusCode: tc.statusCode, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-failover"}, + }, + Body: io.NopCloser(strings.NewReader(tc.body)), + } + upstream := &httpUpstreamRecorder{resp: resp} + repo := &openAIPassthroughFailoverRepo{} + rateSvc := &RateLimitService{ + accountRepo: repo, + cfg: &config.Config{ + RateLimit: config.RateLimitConfig{OverloadCooldownMinutes: 10}, + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + rateLimitService: rateSvc, + } + + account := newAccount(tc.accountType) + start := time.Now() + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, tc.statusCode, failoverErr.StatusCode) + require.False(t, c.Writer.Written(), "429/529 passthrough 应返回 failover 错误给上层换号,而不是直接向客户端写响应") + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + arr, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, arr) + require.True(t, arr[len(arr)-1].Passthrough) + require.Equal(t, "failover", arr[len(arr)-1].Kind) + require.Equal(t, tc.statusCode, arr[len(arr)-1].UpstreamStatusCode) + + tc.assertRepo(t, repo, start) + }) + } } func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 0f004b01..dc094d43 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -3,30 +3,15 @@ package service import ( "context" "crypto/subtle" - "encoding/json" - "io" "log/slog" "net/http" - "regexp" - "sort" - "strconv" "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) -var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" - -var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) - -type soraSessionChunk struct { - index int - value string -} - // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -127,18 +112,19 @@ type OpenAIExchangeCodeInput struct { // OpenAITokenInfo represents the token information for OpenAI type OpenAITokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token,omitempty"` - ExpiresIn int64 `json:"expires_in"` - ExpiresAt int64 `json:"expires_at"` - ClientID string `json:"client_id,omitempty"` - Email string `json:"email,omitempty"` - ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` - ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` - OrganizationID string `json:"organization_id,omitempty"` - PlanType string `json:"plan_type,omitempty"` - PrivacyMode string `json:"privacy_mode,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token,omitempty"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + ClientID string `json:"client_id,omitempty"` + Email string `json:"email,omitempty"` + ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` + ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` + OrganizationID string `json:"organization_id,omitempty"` + PlanType string `json:"plan_type,omitempty"` + SubscriptionExpiresAt string `json:"subscription_expires_at,omitempty"` + PrivacyMode string `json:"privacy_mode,omitempty"` } // ExchangeCode exchanges authorization code for tokens @@ -214,6 +200,8 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch tokenInfo.PlanType = userInfo.PlanType } + s.enrichTokenInfo(ctx, tokenInfo, proxyURL) + return tokenInfo, nil } @@ -222,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") } -// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id. func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { @@ -259,242 +247,46 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre tokenInfo.PlanType = userInfo.PlanType } - // id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全 - if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { - // 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号 - orgID := tokenInfo.OrganizationID - if orgID == "" { - if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil { - orgID = atClaims.OpenAIAuth.POID - } - } - if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil { - if tokenInfo.PlanType == "" && info.PlanType != "" { - tokenInfo.PlanType = info.PlanType - } - if tokenInfo.Email == "" && info.Email != "" { - tokenInfo.Email = info.Email - } - } - } - - // 尝试设置隐私(关闭训练数据共享),best-effort - if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { - tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) - } + s.enrichTokenInfo(ctx, tokenInfo, proxyURL) return tokenInfo, nil } -// ExchangeSoraSessionToken exchanges Sora session_token to access_token. -func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { - sessionToken = normalizeSoraSessionTokenInput(sessionToken) - if strings.TrimSpace(sessionToken) == "" { - return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") +// enrichTokenInfo 通过 ChatGPT backend-api 补全 tokenInfo 并设置隐私(best-effort)。 +// 从 accounts/check 获取最新 plan_type、subscription_expires_at、email, +// 然后尝试关闭训练数据共享。适用于所有获取/刷新 token 的路径。 +func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *OpenAITokenInfo, proxyURL string) { + if tokenInfo.AccessToken == "" || s.privacyClientFactory == nil { + return } - proxyURL, err := s.resolveProxyURL(ctx, proxyID) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) - if err != nil { - return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) - } - req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) - req.Header.Set("Accept", "application/json") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 120 * time.Second, - }) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) - } - resp, err := client.Do(req) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) - } - defer func() { _ = resp.Body.Close() }() - - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if resp.StatusCode != http.StatusOK { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - - var sessionResp struct { - AccessToken string `json:"accessToken"` - Expires string `json:"expires"` - User struct { - Email string `json:"email"` - Name string `json:"name"` - } `json:"user"` - } - if err := json.Unmarshal(body, &sessionResp); err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) - } - if strings.TrimSpace(sessionResp.AccessToken) == "" { - return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") - } - - expiresAt := time.Now().Add(time.Hour).Unix() - if strings.TrimSpace(sessionResp.Expires) != "" { - if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { - expiresAt = parsed.Unix() + // 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号 + orgID := tokenInfo.OrganizationID + if orgID == "" { + if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil { + orgID = atClaims.OpenAIAuth.POID } } - expiresIn := expiresAt - time.Now().Unix() - if expiresIn < 0 { - expiresIn = 0 + if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil { + if info.PlanType != "" { + tokenInfo.PlanType = info.PlanType + } + if info.SubscriptionExpiresAt != "" { + tokenInfo.SubscriptionExpiresAt = info.SubscriptionExpiresAt + } + if tokenInfo.Email == "" && info.Email != "" { + tokenInfo.Email = info.Email + } } - return &OpenAITokenInfo{ - AccessToken: strings.TrimSpace(sessionResp.AccessToken), - ExpiresIn: expiresIn, - ExpiresAt: expiresAt, - ClientID: openai.SoraClientID, - Email: strings.TrimSpace(sessionResp.User.Email), - }, nil + // 尝试设置隐私(关闭训练数据共享),best-effort + tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) } -func normalizeSoraSessionTokenInput(raw string) string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "" - } - - matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) - if len(matches) == 0 { - return sanitizeSessionToken(trimmed) - } - - chunkMatches := make([]soraSessionChunk, 0, len(matches)) - singleValues := make([]string, 0, len(matches)) - - for _, match := range matches { - if len(match) < 3 { - continue - } - - value := sanitizeSessionToken(match[2]) - if value == "" { - continue - } - - if strings.TrimSpace(match[1]) == "" { - singleValues = append(singleValues, value) - continue - } - - idx, err := strconv.Atoi(strings.TrimSpace(match[1])) - if err != nil || idx < 0 { - continue - } - chunkMatches = append(chunkMatches, soraSessionChunk{ - index: idx, - value: value, - }) - } - - if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { - return merged - } - - if len(singleValues) > 0 { - return singleValues[len(singleValues)-1] - } - - return "" -} - -func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { - if len(chunks) == 0 { - return "" - } - - byIndex := make(map[int]string, len(chunks)) - for _, chunk := range chunks { - byIndex[chunk.index] = chunk.value - } - - if _, ok := byIndex[0]; !ok { - return "" - } - if requireComplete { - for idx := 0; idx <= requiredMaxIndex; idx++ { - if _, ok := byIndex[idx]; !ok { - return "" - } - } - } - - orderedIndexes := make([]int, 0, len(byIndex)) - for idx := range byIndex { - orderedIndexes = append(orderedIndexes, idx) - } - sort.Ints(orderedIndexes) - - var builder strings.Builder - for _, idx := range orderedIndexes { - if _, err := builder.WriteString(byIndex[idx]); err != nil { - return "" - } - } - return sanitizeSessionToken(builder.String()) -} - -func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { - if len(chunks) == 0 { - return "" - } - - requiredMaxIndex := 0 - for _, chunk := range chunks { - if chunk.index > requiredMaxIndex { - requiredMaxIndex = chunk.index - } - } - - groupStarts := make([]int, 0, len(chunks)) - for idx, chunk := range chunks { - if chunk.index == 0 { - groupStarts = append(groupStarts, idx) - } - } - - if len(groupStarts) == 0 { - return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) - } - - for i := len(groupStarts) - 1; i >= 0; i-- { - start := groupStarts[i] - end := len(chunks) - if i+1 < len(groupStarts) { - end = groupStarts[i+1] - } - if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { - return merged - } - } - - return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) -} - -func sanitizeSessionToken(raw string) string { - token := strings.TrimSpace(raw) - token = strings.Trim(token, "\"'`") - token = strings.TrimSuffix(token, ";") - return strings.TrimSpace(token) -} - -// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +// RefreshAccountToken refreshes token for an OpenAI OAuth account func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + if account.Platform != PlatformOpenAI { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") } if account.Type != AccountTypeOAuth { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") @@ -567,6 +359,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.PlanType != "" { creds["plan_type"] = tokenInfo.PlanType } + if tokenInfo.SubscriptionExpiresAt != "" { + creds["subscription_expires_at"] = tokenInfo.SubscriptionExpiresAt + } if strings.TrimSpace(tokenInfo.ClientID) != "" { creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) } @@ -579,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } -func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { - if proxyID == nil { - return "", nil - } - proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) - if err != nil { - return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) - } - if proxy == nil { - return "", nil - } - return proxy.URL(), nil -} - func normalizeOpenAIOAuthPlatform(platform string) string { - switch strings.ToLower(strings.TrimSpace(platform)) { - case PlatformSora: - return openai.OAuthPlatformSora - default: - return openai.OAuthPlatformOpenAI - } + return openai.OAuthPlatformOpenAI } diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go index 5f26903d..f3b507ca 100644 --- a/backend/internal/service/openai_oauth_service_auth_url_test.go +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { require.True(t, ok) require.Equal(t, openai.ClientID, session.ClientID) } - -// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 -// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 -func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) - defer svc.Stop() - - result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) - require.NoError(t, err) - require.NotEmpty(t, result.AuthURL) - require.NotEmpty(t, result.SessionID) - - parsed, err := url.Parse(result.AuthURL) - require.NoError(t, err) - q := parsed.Query() - require.Equal(t, openai.ClientID, q.Get("client_id")) - require.Empty(t, q.Get("codex_cli_simplified_flow")) - - session, ok := svc.sessionStore.Get(result.SessionID) - require.True(t, ok) - require.Equal(t, openai.ClientID, session.ClientID) -} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go deleted file mode 100644 index 08da8557..00000000 --- a/backend/internal/service/openai_oauth_service_sora_session_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package service - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/stretchr/testify/require" -) - -type openaiOAuthClientNoopStub struct{} - -func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) - require.NoError(t, err) - require.NotNil(t, info) - require.Equal(t, "at-token", info.AccessToken) - require.Equal(t, "demo@example.com", info.Email) - require.Greater(t, info.ExpiresAt, int64(0)) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) - require.Error(t, err) - require.Contains(t, err.Error(), "missing access token") -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "set-cookie", - "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", - "set-cookie", - "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", - "set-cookie", - "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} diff --git a/backend/internal/service/openai_privacy_service.go b/backend/internal/service/openai_privacy_service.go index 6bc71ab9..da6dbefc 100644 --- a/backend/internal/service/openai_privacy_service.go +++ b/backend/internal/service/openai_privacy_service.go @@ -56,6 +56,10 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Origin", "https://chatgpt.com"). SetHeader("Referer", "https://chatgpt.com/"). + SetHeader("Accept", "application/json"). + SetHeader("sec-fetch-mode", "cors"). + SetHeader("sec-fetch-site", "same-origin"). + SetHeader("sec-fetch-dest", "empty"). SetQueryParam("feature", "training_allowed"). SetQueryParam("value", "false"). Patch(openAISettingsURL) @@ -84,8 +88,9 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto // ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息 type ChatGPTAccountInfo struct { - PlanType string - Email string + PlanType string + Email string + SubscriptionExpiresAt string // entitlement.expires_at (RFC3339) } const chatGPTAccountsCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" @@ -138,14 +143,20 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac // 优先匹配 orgID 对应的账号(access_token JWT 中的 poid) if orgID != "" { - if matched := extractPlanFromAccount(accounts, orgID); matched != "" { - info.PlanType = matched + if acctRaw, ok := accounts[orgID]; ok { + if acct, ok := acctRaw.(map[string]any); ok { + fillAccountInfo(info, acct) + } } } // 未匹配到时,遍历所有账号:优先 is_default,次选非 free if info.PlanType == "" { - var defaultPlan, paidPlan, anyPlan string + type candidate struct { + planType string + expiresAt string + } + var defaultC, paidC, anyC candidate for _, acctRaw := range accounts { acct, ok := acctRaw.(map[string]any) if !ok { @@ -155,26 +166,27 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac if planType == "" { continue } - if anyPlan == "" { - anyPlan = planType + ea := extractEntitlementExpiresAt(acct) + if anyC.planType == "" { + anyC = candidate{planType, ea} } if account, ok := acct["account"].(map[string]any); ok { if isDefault, _ := account["is_default"].(bool); isDefault { - defaultPlan = planType + defaultC = candidate{planType, ea} } } - if !strings.EqualFold(planType, "free") && paidPlan == "" { - paidPlan = planType + if !strings.EqualFold(planType, "free") && paidC.planType == "" { + paidC = candidate{planType, ea} } } // 优先级:default > 非 free > 任意 switch { - case defaultPlan != "": - info.PlanType = defaultPlan - case paidPlan != "": - info.PlanType = paidPlan + case defaultC.planType != "": + info.PlanType, info.SubscriptionExpiresAt = defaultC.planType, defaultC.expiresAt + case paidC.planType != "": + info.PlanType, info.SubscriptionExpiresAt = paidC.planType, paidC.expiresAt default: - info.PlanType = anyPlan + info.PlanType, info.SubscriptionExpiresAt = anyC.planType, anyC.expiresAt } } @@ -183,21 +195,14 @@ func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFac return nil } - slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID) + slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "subscription_expires_at", info.SubscriptionExpiresAt, "org_id", orgID) return info } -// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type -func extractPlanFromAccount(accounts map[string]any, accountKey string) string { - acctRaw, ok := accounts[accountKey] - if !ok { - return "" - } - acct, ok := acctRaw.(map[string]any) - if !ok { - return "" - } - return extractPlanType(acct) +// fillAccountInfo 从单个 account 对象中提取 plan_type 和 subscription_expires_at +func fillAccountInfo(info *ChatGPTAccountInfo, acct map[string]any) { + info.PlanType = extractPlanType(acct) + info.SubscriptionExpiresAt = extractEntitlementExpiresAt(acct) } // extractPlanType 从单个 account 对象中提取 plan_type @@ -215,6 +220,17 @@ func extractPlanType(acct map[string]any) string { return "" } +// extractEntitlementExpiresAt 从 entitlement 中提取 expires_at。 +// 预期为 RFC3339 字符串格式,如 "2026-05-02T20:32:12+00:00"。 +func extractEntitlementExpiresAt(acct map[string]any) string { + entitlement, ok := acct["entitlement"].(map[string]any) + if !ok { + return "" + } + ea, _ := entitlement["expires_at"].(string) + return ea +} + func truncate(s string, n int) string { if len(s) <= n { return s diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 69477ce7..e438588e 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() { // OpenAITokenCache token cache interface. type OpenAITokenCache = GeminiTokenCache -// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts. +// OpenAITokenProvider manages access_token for OpenAI OAuth accounts. type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache @@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai/sora oauth account") + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou p.metrics.refreshRequests.Add(1) p.metrics.touchNow() - // Sora accounts skip OpenAI OAuth refresh and keep existing token path. - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) - refreshFailed = true - } else { - result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) - if err != nil { - if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { - return "", err - } - slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else if result.LockHeld { - if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { - p.metrics.lockContention.Add(1) - p.metrics.touchNow() - token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) - if waitErr != nil { - return "", waitErr - } - if strings.TrimSpace(token) != "" { - slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) - return token, nil - } - } - } else if result.Refreshed { - p.metrics.refreshSuccess.Add(1) - account = result.Account - expiresAt = account.GetCredentialAsTime("expires_at") - } else { - account = result.Account - expiresAt = account.GetCredentialAsTime("expires_at") + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err } + slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) + refreshFailed = true + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + } + } else if result.Refreshed { + p.metrics.refreshSuccess.Add(1) + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") } } else if needsRefresh && p.tokenCache != nil { // Backward-compatible test path when refreshAPI is not injected. diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index 1cd92367..e81fb465 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai/sora oauth account") + require.Contains(t, err.Error(), "not an openai oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai/sora oauth account") + require.Contains(t, err.Error(), "not an openai oauth account") require.Empty(t, token) } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 4f1837c4..83849bf3 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - mappedModel := account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } - if mappedModel != originalModel { - next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) + if upstreamModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } @@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } + mappedModel = normalizeOpenAIModelForUpstream(account, account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 8c5c9368..3834dcb7 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fdabbafd..c0e814ab 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry if s.gatewayService == nil { return nil, fmt.Errorf("gateway service not available") } - return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制 + return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制 default: return nil, fmt.Errorf("unsupported retry type: %s", reqType) } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 5623d4b7..3b3f31c3 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -70,7 +70,8 @@ type LiteLLMModelPricing struct { LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -94,6 +95,7 @@ type LiteLLMRawEntry struct { Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` OutputCostPerImage *float64 `json:"output_cost_per_image"` + OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"` } // PricingService 动态价格服务 @@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } + if entry.OutputCostPerImageToken != nil { + pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken + } result[modelName] = pricing } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index aa0ae200..4f5b57cc 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -161,6 +161,16 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc shouldDisable = true break } + // OpenAI: {"detail":"Unauthorized"} 表示 token 完全无效(非标准 OpenAI 错误格式),直接标记 error + if account.Platform == PlatformOpenAI && gjson.GetBytes(responseBody, "detail").String() == "Unauthorized" { + msg := "Unauthorized (401): account authentication failed permanently" + if upstreamMsg != "" { + msg = "Unauthorized (401): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true + break + } // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。 // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。 if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity { diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b22da752..9ced6201 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -131,9 +131,9 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ return nil, errors.New("count must be greater than 0") } - // 邀请码类型不需要数值,其他类型需要 - if req.Type != RedeemTypeInvitation && req.Value <= 0 { - return nil, errors.New("value must be greater than 0") + // 邀请码类型不需要数值,其他类型需要非零值(支持负数用于退款) + if req.Type != RedeemTypeInvitation && req.Value == 0 { + return nil, errors.New("value must not be zero") } if req.Count > 1000 { @@ -188,8 +188,8 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error if code.Type == "" { code.Type = RedeemTypeBalance } - if code.Type != RedeemTypeInvitation && code.Value <= 0 { - return errors.New("value must be greater than 0") + if code.Type != RedeemTypeInvitation && code.Value == 0 { + return errors.New("value must not be zero") } if code.Status == "" { code.Status = StatusUnused @@ -292,7 +292,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( if err != nil { return nil, fmt.Errorf("get user: %w", err) } - _ = user // 使用变量避免未使用错误 // 使用数据库事务保证兑换码标记与权益发放的原子性 tx, err := s.entClient.Tx(ctx) @@ -316,31 +315,46 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 执行兑换逻辑(兑换码已被锁定,此时可安全操作) switch redeemCode.Type { case RedeemTypeBalance: - // 增加用户余额 - if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { + amount := redeemCode.Value + // 负数为退款扣减,余额最低为 0 + if amount < 0 && user.Balance+amount < 0 { + amount = -user.Balance + } + if err := s.userRepo.UpdateBalance(txCtx, userID, amount); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } case RedeemTypeConcurrency: - // 增加用户并发数 - if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { + delta := int(redeemCode.Value) + // 负数为退款扣减,并发数最低为 0 + if delta < 0 && user.Concurrency+delta < 0 { + delta = -user.Concurrency + } + if err := s.userRepo.UpdateConcurrency(txCtx, userID, delta); err != nil { return nil, fmt.Errorf("update user concurrency: %w", err) } case RedeemTypeSubscription: validityDays := redeemCode.ValidityDays - if validityDays <= 0 { - validityDays = 30 - } - _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ - UserID: userID, - GroupID: *redeemCode.GroupID, - ValidityDays: validityDays, - AssignedBy: 0, // 系统分配 - Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code), - }) - if err != nil { - return nil, fmt.Errorf("assign or extend subscription: %w", err) + if validityDays < 0 { + // 负数天数:缩短订阅,减到 0 则取消订阅 + if err := s.reduceOrCancelSubscription(txCtx, userID, *redeemCode.GroupID, -validityDays, redeemCode.Code); err != nil { + return nil, fmt.Errorf("reduce or cancel subscription: %w", err) + } + } else { + if validityDays == 0 { + validityDays = 30 + } + _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: *redeemCode.GroupID, + ValidityDays: validityDays, + AssignedBy: 0, // 系统分配 + Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code), + }) + if err != nil { + return nil, fmt.Errorf("assign or extend subscription: %w", err) + } } default: @@ -475,3 +489,51 @@ func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit } return codes, nil } + +// reduceOrCancelSubscription 缩短订阅天数,剩余天数 <= 0 时取消订阅 +func (s *RedeemService) reduceOrCancelSubscription(ctx context.Context, userID, groupID int64, reduceDays int, code string) error { + sub, err := s.subscriptionService.userSubRepo.GetByUserIDAndGroupID(ctx, userID, groupID) + if err != nil { + return ErrSubscriptionNotFound + } + + now := time.Now() + remaining := int(sub.ExpiresAt.Sub(now).Hours() / 24) + if remaining < 0 { + remaining = 0 + } + + notes := fmt.Sprintf("通过兑换码 %s 退款扣减 %d 天", code, reduceDays) + + if remaining <= reduceDays { + // 剩余天数不足,直接取消订阅 + if err := s.subscriptionService.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired); err != nil { + return fmt.Errorf("cancel subscription: %w", err) + } + // 设置过期时间为当前时间 + if err := s.subscriptionService.userSubRepo.ExtendExpiry(ctx, sub.ID, now); err != nil { + return fmt.Errorf("set subscription expiry: %w", err) + } + } else { + // 缩短天数 + newExpiresAt := sub.ExpiresAt.AddDate(0, 0, -reduceDays) + if err := s.subscriptionService.userSubRepo.ExtendExpiry(ctx, sub.ID, newExpiresAt); err != nil { + return fmt.Errorf("reduce subscription: %w", err) + } + } + + // 追加备注 + newNotes := sub.Notes + if newNotes != "" { + newNotes += "\n" + } + newNotes += notes + if err := s.subscriptionService.userSubRepo.UpdateNotes(ctx, sub.ID, newNotes); err != nil { + return fmt.Errorf("update subscription notes: %w", err) + } + + // 失效缓存 + s.subscriptionService.InvalidateSubCache(userID, groupID) + + return nil +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4c9540f1..d1330abb 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -152,6 +152,14 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int return s.accountRepo.GetByID(fallbackCtx, accountID) } +// GetGroupByID 获取分组信息(供调度器使用) +func (s *SchedulerSnapshotService) GetGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if s.groupRepo == nil { + return nil, nil + } + return s.groupRepo.GetByID(ctx, groupID) +} + // UpdateAccountInCache 立即更新 Redis 中单个账号的数据(用于模型限流后立即生效) func (s *SchedulerSnapshotService) UpdateAccountInCache(ctx context.Context, account *Account) error { if s.cache == nil || account == nil { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 1a24bad1..b7145121 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -22,8 +22,6 @@ import ( var ( ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") - ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") - ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ErrDefaultSubGroupInvalid = infraerrors.BadRequest( "DEFAULT_SUBSCRIPTION_GROUP_INVALID", "default subscription group must exist and be subscription type", @@ -104,7 +102,6 @@ type SettingService struct { defaultSubGroupReader DefaultSubscriptionGroupReader cfg *config.Config onUpdate func() // Callback when settings are updated (for cache invalidation) - onS3Update func() // Callback when Sora S3 settings are updated version string // Application version } @@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, - SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, @@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, @@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } -// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 -func (s *SettingService) SetOnS3UpdateCallback(callback func()) { - s.onS3Update = callback -} - // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` @@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, @@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) - updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints @@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", - SettingKeySoraClientEnabled: "false", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), @@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", @@ -1542,6 +1527,18 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be if !validScopes[rule.Scope] { return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope) } + // Validate model_whitelist patterns + for j, pattern := range rule.ModelWhitelist { + trimmed := strings.TrimSpace(pattern) + if trimmed == "" { + return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j) + } + settings.Rules[i].ModelWhitelist[j] = trimmed + } + // Validate fallback_action + if rule.FallbackAction != "" && !validActions[rule.FallbackAction] { + return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction) + } } data, err := json.Marshal(settings) @@ -1583,607 +1580,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } - -type soraS3ProfilesStore struct { - ActiveProfileID string `json:"active_profile_id"` - Items []soraS3ProfileStoreItem `json:"items"` -} - -type soraS3ProfileStoreItem struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) -func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - profiles, err := s.ListSoraS3Profiles(ctx) - if err != nil { - return nil, err - } - - activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) - if activeProfile == nil { - return &SoraS3Settings{}, nil - } - - return &SoraS3Settings{ - Enabled: activeProfile.Enabled, - Endpoint: activeProfile.Endpoint, - Region: activeProfile.Region, - Bucket: activeProfile.Bucket, - AccessKeyID: activeProfile.AccessKeyID, - SecretAccessKey: activeProfile.SecretAccessKey, - SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, - Prefix: activeProfile.Prefix, - ForcePathStyle: activeProfile.ForcePathStyle, - CDNURL: activeProfile.CDNURL, - DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, - }, nil -} - -// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) -func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { - if settings == nil { - return fmt.Errorf("settings cannot be nil") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return err - } - - now := time.Now().UTC().Format(time.RFC3339) - activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) - if activeIndex < 0 { - activeID := "default" - if hasSoraS3ProfileID(store.Items, activeID) { - activeID = fmt.Sprintf("default-%d", time.Now().Unix()) - } - store.Items = append(store.Items, soraS3ProfileStoreItem{ - ProfileID: activeID, - Name: "Default", - UpdatedAt: now, - }) - store.ActiveProfileID = activeID - activeIndex = len(store.Items) - 1 - } - - active := store.Items[activeIndex] - active.Enabled = settings.Enabled - active.Endpoint = strings.TrimSpace(settings.Endpoint) - active.Region = strings.TrimSpace(settings.Region) - active.Bucket = strings.TrimSpace(settings.Bucket) - active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) - active.Prefix = strings.TrimSpace(settings.Prefix) - active.ForcePathStyle = settings.ForcePathStyle - active.CDNURL = strings.TrimSpace(settings.CDNURL) - active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) - if settings.SecretAccessKey != "" { - active.SecretAccessKey = settings.SecretAccessKey - } - active.UpdatedAt = now - store.Items[activeIndex] = active - - return s.persistSoraS3ProfilesStore(ctx, store) -} - -// ListSoraS3Profiles 获取 Sora S3 多配置列表 -func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - return convertSoraS3ProfilesStore(store), nil -} - -// CreateSoraS3Profile 创建 Sora S3 配置 -func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { - if profile == nil { - return nil, fmt.Errorf("profile cannot be nil") - } - - profileID := strings.TrimSpace(profile.ProfileID) - if profileID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - name := strings.TrimSpace(profile.Name) - if name == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - if hasSoraS3ProfileID(store.Items, profileID) { - return nil, ErrSoraS3ProfileExists - } - - now := time.Now().UTC().Format(time.RFC3339) - store.Items = append(store.Items, soraS3ProfileStoreItem{ - ProfileID: profileID, - Name: name, - Enabled: profile.Enabled, - Endpoint: strings.TrimSpace(profile.Endpoint), - Region: strings.TrimSpace(profile.Region), - Bucket: strings.TrimSpace(profile.Bucket), - AccessKeyID: strings.TrimSpace(profile.AccessKeyID), - SecretAccessKey: profile.SecretAccessKey, - Prefix: strings.TrimSpace(profile.Prefix), - ForcePathStyle: profile.ForcePathStyle, - CDNURL: strings.TrimSpace(profile.CDNURL), - DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }) - - if setActive || store.ActiveProfileID == "" { - store.ActiveProfileID = profileID - } - - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - created := findSoraS3ProfileByID(profiles.Items, profileID) - if created == nil { - return nil, ErrSoraS3ProfileNotFound - } - return created, nil -} - -// UpdateSoraS3Profile 更新 Sora S3 配置 -func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { - if profile == nil { - return nil, fmt.Errorf("profile cannot be nil") - } - - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return nil, ErrSoraS3ProfileNotFound - } - - target := store.Items[targetIndex] - name := strings.TrimSpace(profile.Name) - if name == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") - } - target.Name = name - target.Enabled = profile.Enabled - target.Endpoint = strings.TrimSpace(profile.Endpoint) - target.Region = strings.TrimSpace(profile.Region) - target.Bucket = strings.TrimSpace(profile.Bucket) - target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID) - target.Prefix = strings.TrimSpace(profile.Prefix) - target.ForcePathStyle = profile.ForcePathStyle - target.CDNURL = strings.TrimSpace(profile.CDNURL) - target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0) - if profile.SecretAccessKey != "" { - target.SecretAccessKey = profile.SecretAccessKey - } - target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - store.Items[targetIndex] = target - - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - updated := findSoraS3ProfileByID(profiles.Items, targetID) - if updated == nil { - return nil, ErrSoraS3ProfileNotFound - } - return updated, nil -} - -// DeleteSoraS3Profile 删除 Sora S3 配置 -func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return ErrSoraS3ProfileNotFound - } - - store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) - if store.ActiveProfileID == targetID { - store.ActiveProfileID = "" - if len(store.Items) > 0 { - store.ActiveProfileID = store.Items[0].ProfileID - } - } - - return s.persistSoraS3ProfilesStore(ctx, store) -} - -// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 -func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return nil, ErrSoraS3ProfileNotFound - } - - store.ActiveProfileID = targetID - store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) - if active == nil { - return nil, ErrSoraS3ProfileNotFound - } - return active, nil -} - -func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { - raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) - if err == nil { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return &soraS3ProfilesStore{}, nil - } - var store soraS3ProfilesStore - if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil - } - normalized := normalizeSoraS3ProfilesStore(store) - return &normalized, nil - } - - if !errors.Is(err, ErrSettingNotFound) { - return nil, fmt.Errorf("get sora s3 profiles: %w", err) - } - - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, legacyErr - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil -} - -func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { - if store == nil { - return fmt.Errorf("sora s3 profiles store cannot be nil") - } - - normalized := normalizeSoraS3ProfilesStore(*store) - data, err := json.Marshal(normalized) - if err != nil { - return fmt.Errorf("marshal sora s3 profiles: %w", err) - } - - updates := map[string]string{ - SettingKeySoraS3Profiles: string(data), - } - - active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) - if active == nil { - updates[SettingKeySoraS3Enabled] = "false" - updates[SettingKeySoraS3Endpoint] = "" - updates[SettingKeySoraS3Region] = "" - updates[SettingKeySoraS3Bucket] = "" - updates[SettingKeySoraS3AccessKeyID] = "" - updates[SettingKeySoraS3Prefix] = "" - updates[SettingKeySoraS3ForcePathStyle] = "false" - updates[SettingKeySoraS3CDNURL] = "" - updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" - updates[SettingKeySoraS3SecretAccessKey] = "" - } else { - updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) - updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) - updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) - updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) - updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) - updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) - updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) - updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) - updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) - updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey - } - - if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { - return err - } - - if s.onUpdate != nil { - s.onUpdate() - } - if s.onS3Update != nil { - s.onS3Update() - } - return nil -} - -func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - keys := []string{ - SettingKeySoraS3Enabled, - SettingKeySoraS3Endpoint, - SettingKeySoraS3Region, - SettingKeySoraS3Bucket, - SettingKeySoraS3AccessKeyID, - SettingKeySoraS3SecretAccessKey, - SettingKeySoraS3Prefix, - SettingKeySoraS3ForcePathStyle, - SettingKeySoraS3CDNURL, - SettingKeySoraDefaultStorageQuotaBytes, - } - - values, err := s.settingRepo.GetMultiple(ctx, keys) - if err != nil { - return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) - } - - result := &SoraS3Settings{ - Enabled: values[SettingKeySoraS3Enabled] == "true", - Endpoint: values[SettingKeySoraS3Endpoint], - Region: values[SettingKeySoraS3Region], - Bucket: values[SettingKeySoraS3Bucket], - AccessKeyID: values[SettingKeySoraS3AccessKeyID], - SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], - SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", - Prefix: values[SettingKeySoraS3Prefix], - ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", - CDNURL: values[SettingKeySoraS3CDNURL], - } - if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { - result.DefaultStorageQuotaBytes = v - } - return result, nil -} - -func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { - seen := make(map[string]struct{}, len(store.Items)) - normalized := soraS3ProfilesStore{ - ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), - Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), - } - now := time.Now().UTC().Format(time.RFC3339) - - for idx := range store.Items { - item := store.Items[idx] - item.ProfileID = strings.TrimSpace(item.ProfileID) - if item.ProfileID == "" { - item.ProfileID = fmt.Sprintf("profile-%d", idx+1) - } - if _, exists := seen[item.ProfileID]; exists { - continue - } - seen[item.ProfileID] = struct{}{} - - item.Name = strings.TrimSpace(item.Name) - if item.Name == "" { - item.Name = item.ProfileID - } - item.Endpoint = strings.TrimSpace(item.Endpoint) - item.Region = strings.TrimSpace(item.Region) - item.Bucket = strings.TrimSpace(item.Bucket) - item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) - item.Prefix = strings.TrimSpace(item.Prefix) - item.CDNURL = strings.TrimSpace(item.CDNURL) - item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) - item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) - if item.UpdatedAt == "" { - item.UpdatedAt = now - } - normalized.Items = append(normalized.Items, item) - } - - if len(normalized.Items) == 0 { - normalized.ActiveProfileID = "" - return normalized - } - - if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { - return normalized - } - - normalized.ActiveProfileID = normalized.Items[0].ProfileID - return normalized -} - -func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { - if store == nil { - return &SoraS3ProfileList{} - } - items := make([]SoraS3Profile, 0, len(store.Items)) - for idx := range store.Items { - item := store.Items[idx] - items = append(items, SoraS3Profile{ - ProfileID: item.ProfileID, - Name: item.Name, - IsActive: item.ProfileID == store.ActiveProfileID, - Enabled: item.Enabled, - Endpoint: item.Endpoint, - Region: item.Region, - Bucket: item.Bucket, - AccessKeyID: item.AccessKeyID, - SecretAccessKey: item.SecretAccessKey, - SecretAccessKeyConfigured: item.SecretAccessKey != "", - Prefix: item.Prefix, - ForcePathStyle: item.ForcePathStyle, - CDNURL: item.CDNURL, - DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, - UpdatedAt: item.UpdatedAt, - }) - } - return &SoraS3ProfileList{ - ActiveProfileID: store.ActiveProfileID, - Items: items, - } -} - -func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == activeProfileID { - return &items[idx] - } - } - if len(items) == 0 { - return nil - } - return &items[0] -} - -func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == profileID { - return &items[idx] - } - } - return nil -} - -func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem { - for idx := range items { - if items[idx].ProfileID == activeProfileID { - return &items[idx] - } - } - if len(items) == 0 { - return nil - } - return &items[0] -} - -func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { - for idx := range items { - if items[idx].ProfileID == profileID { - return idx - } - } - return -1 -} - -func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { - return findSoraS3ProfileIndex(items, profileID) >= 0 -} - -func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { - if settings == nil { - return true - } - if settings.Enabled { - return false - } - if strings.TrimSpace(settings.Endpoint) != "" { - return false - } - if strings.TrimSpace(settings.Region) != "" { - return false - } - if strings.TrimSpace(settings.Bucket) != "" { - return false - } - if strings.TrimSpace(settings.AccessKeyID) != "" { - return false - } - if settings.SecretAccessKey != "" { - return false - } - if strings.TrimSpace(settings.Prefix) != "" { - return false - } - if strings.TrimSpace(settings.CDNURL) != "" { - return false - } - return settings.DefaultStorageQuotaBytes == 0 -} - -func maxInt64(value int64, min int64) int64 { - if value < min { - return min - } - return value -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 411939bb..473d7297 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -41,7 +41,6 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string - SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints @@ -107,7 +106,6 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string - SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints @@ -116,46 +114,6 @@ type PublicSettings struct { Version string } -// SoraS3Settings Sora S3 存储配置 -type SoraS3Settings struct { - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// SoraS3Profile Sora S3 多配置项(服务内部模型) -type SoraS3Profile struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// SoraS3ProfileList Sora S3 多配置列表 -type SoraS3ProfileList struct { - ActiveProfileID string `json:"active_profile_id"` - Items []SoraS3Profile `json:"items"` -} - // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 @@ -220,10 +178,13 @@ const ( // BetaPolicyRule 单条 Beta 策略规则 type BetaPolicyRule struct { - BetaToken string `json:"beta_token"` // beta token 值 - Action string `json:"action"` // "pass" | "filter" | "block" - Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" - ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) + BetaToken string `json:"beta_token"` // beta token 值 + Action string `json:"action"` // "pass" | "filter" | "block" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" + ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) + ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效) + FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式 + FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效) } // BetaPolicySettings Beta 策略配置 diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go deleted file mode 100644 index eccc1acf..00000000 --- a/backend/internal/service/sora_account_service.go +++ /dev/null @@ -1,40 +0,0 @@ -package service - -import "context" - -// SoraAccountRepository Sora 账号扩展表仓储接口 -// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 -// -// 设计说明: -// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 -// - Sora gateway 优先读取此表的字段以获得更好的查询性能 -// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 -// - Token 刷新时需要同时更新两个表以保持数据一致性 -type SoraAccountRepository interface { - // Upsert 创建或更新 Sora 账号扩展信息 - // accountID: 关联的 accounts.id - // updates: 要更新的字段,支持 access_token、refresh_token、session_token - // - // 如果记录不存在则创建,存在则更新。 - // 用于: - // 1. 创建 Sora 账号时初始化扩展表 - // 2. Token 刷新时同步更新扩展表 - Upsert(ctx context.Context, accountID int64, updates map[string]any) error - - // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 - // 返回 nil, nil 表示记录不存在(非错误) - GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) - - // Delete 删除 Sora 账号扩展信息 - // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 - Delete(ctx context.Context, accountID int64) error -} - -// SoraAccount Sora 账号扩展信息 -// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 -type SoraAccount struct { - AccountID int64 // 关联的 accounts.id - AccessToken string // OAuth access_token - RefreshToken string // OAuth refresh_token - SessionToken string // Session token(可选,用于 ST→AT 兜底) -} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go deleted file mode 100644 index 0a914d2d..00000000 --- a/backend/internal/service/sora_client.go +++ /dev/null @@ -1,117 +0,0 @@ -package service - -import ( - "context" - "fmt" - "net/http" -) - -// SoraClient 定义直连 Sora 的任务操作接口。 -type SoraClient interface { - Enabled() bool - UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) - CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) - CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) - CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) - UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) - GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) - DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) - UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) - FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) - SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error - DeleteCharacter(ctx context.Context, account *Account, characterID string) error - PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) - DeletePost(ctx context.Context, account *Account, postID string) error - GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) - EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) - GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) - GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) -} - -// SoraImageRequest 图片生成请求参数 -type SoraImageRequest struct { - Prompt string - Width int - Height int - MediaID string -} - -// SoraVideoRequest 视频生成请求参数 -type SoraVideoRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - VideoCount int - MediaID string - RemixTargetID string - CameoIDs []string -} - -// SoraStoryboardRequest 分镜视频生成请求参数 -type SoraStoryboardRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - MediaID string -} - -// SoraImageTaskStatus 图片任务状态 -type SoraImageTaskStatus struct { - ID string - Status string - ProgressPct float64 - URLs []string - ErrorMsg string -} - -// SoraVideoTaskStatus 视频任务状态 -type SoraVideoTaskStatus struct { - ID string - Status string - ProgressPct int - URLs []string - GenerationID string - ErrorMsg string -} - -// SoraCameoStatus 角色处理中间态 -type SoraCameoStatus struct { - Status string - StatusMessage string - DisplayNameHint string - UsernameHint string - ProfileAssetURL string - InstructionSetHint any - InstructionSet any -} - -// SoraCharacterFinalizeRequest 角色定稿请求参数 -type SoraCharacterFinalizeRequest struct { - CameoID string - Username string - DisplayName string - ProfileAssetPointer string - InstructionSet any -} - -// SoraUpstreamError 上游错误 -type SoraUpstreamError struct { - StatusCode int - Message string - Headers http.Header - Body []byte -} - -func (e *SoraUpstreamError) Error() string { - if e == nil { - return "sora upstream error" - } - if e.Message != "" { - return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) - } - return fmt.Sprintf("sora upstream error: %d", e.StatusCode) -} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go deleted file mode 100644 index e9d325f4..00000000 --- a/backend/internal/service/sora_gateway_service.go +++ /dev/null @@ -1,1559 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "math" - "math/rand" - "mime" - "net" - "net/http" - "net/url" - "regexp" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/gin-gonic/gin" -) - -const soraImageInputMaxBytes = 20 << 20 -const soraImageInputMaxRedirects = 3 -const soraImageInputTimeout = 20 * time.Second -const soraVideoInputMaxBytes = 200 << 20 -const soraVideoInputMaxRedirects = 3 -const soraVideoInputTimeout = 60 * time.Second - -var soraImageSizeMap = map[string]string{ - "gpt-image": "360", - "gpt-image-landscape": "540", - "gpt-image-portrait": "540", -} - -var soraBlockedHostnames = map[string]struct{}{ - "localhost": {}, - "localhost.localdomain": {}, - "metadata.google.internal": {}, - "metadata.google.internal.": {}, -} - -var soraBlockedCIDRs = mustParseCIDRs([]string{ - "0.0.0.0/8", - "10.0.0.0/8", - "100.64.0.0/10", - "127.0.0.0/8", - "169.254.0.0/16", - "172.16.0.0/12", - "192.168.0.0/16", - "224.0.0.0/4", - "240.0.0.0/4", - "::/128", - "::1/128", - "fc00::/7", - "fe80::/10", -}) - -// SoraGatewayService handles forwarding requests to Sora upstream. -type SoraGatewayService struct { - soraClient SoraClient - rateLimitService *RateLimitService - httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 - cfg *config.Config -} - -type soraWatermarkOptions struct { - Enabled bool - ParseMethod string - ParseURL string - ParseToken string - FallbackOnFailure bool - DeletePost bool -} - -type soraCharacterOptions struct { - SetPublic bool - DeleteAfterGenerate bool -} - -type soraCharacterFlowResult struct { - CameoID string - CharacterID string - Username string - DisplayName string -} - -var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) -var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) -var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) -var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) - -type soraPreflightChecker interface { - PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error -} - -func NewSoraGatewayService( - soraClient SoraClient, - rateLimitService *RateLimitService, - httpUpstream HTTPUpstream, - cfg *config.Config, -) *SoraGatewayService { - return &SoraGatewayService{ - soraClient: soraClient, - rateLimitService: rateLimitService, - httpUpstream: httpUpstream, - cfg: cfg, - } -} - -func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { - startTime := time.Now() - - // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient - if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { - if s.httpUpstream == nil { - s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) - return nil, errors.New("httpUpstream not configured for sora apikey forwarding") - } - return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) - } - - if s.soraClient == nil || !s.soraClient.Enabled() { - if c != nil { - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "type": "api_error", - "message": "Sora 上游未配置", - }, - }) - } - return nil, errors.New("sora upstream not configured") - } - - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) - return nil, fmt.Errorf("parse request: %w", err) - } - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - if strings.TrimSpace(reqModel) == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) - return nil, errors.New("model is required") - } - originalModel := reqModel - - mappedModel := account.GetMappedModel(reqModel) - var upstreamModel string - if mappedModel != "" && mappedModel != reqModel { - reqModel = mappedModel - upstreamModel = mappedModel - } - - modelCfg, ok := GetSoraModelConfig(reqModel) - if !ok { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) - return nil, fmt.Errorf("unsupported model: %s", reqModel) - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) - prompt = strings.TrimSpace(prompt) - imageInput = strings.TrimSpace(imageInput) - videoInput = strings.TrimSpace(videoInput) - remixTargetID = strings.TrimSpace(remixTargetID) - - if videoInput != "" && modelCfg.Type != "video" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) - return nil, errors.New("video input only supports video models") - } - if videoInput != "" && imageInput != "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) - return nil, errors.New("image input and video input cannot be used together") - } - characterOnly := videoInput != "" && prompt == "" - if modelCfg.Type == "prompt_enhance" && prompt == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - - reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) - if cancel != nil { - defer cancel() - } - if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { - if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - } - - if modelCfg.Type == "prompt_enhance" { - enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - content := strings.TrimSpace(enhancedPrompt) - if content == "" { - content = prompt - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - - characterOpts := parseSoraCharacterOptions(reqBody) - watermarkOpts := parseSoraWatermarkOptions(reqBody) - var characterResult *soraCharacterFlowResult - if videoInput != "" { - videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) - if videoErr != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) - return nil, videoErr - } - characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) - if videoErr != nil { - return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) - } - if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { - characterID := strings.TrimSpace(characterResult.CharacterID) - defer func() { - cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) - defer cancelCleanup() - if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { - log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) - } - }() - } - if characterOnly { - content := "角色创建成功" - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - resp := buildSoraNonStreamResponse(content, reqModel) - if characterResult != nil { - resp["character_id"] = characterResult.CharacterID - resp["cameo_id"] = characterResult.CameoID - resp["character_username"] = characterResult.Username - resp["character_display_name"] = characterResult.DisplayName - } - c.JSON(http.StatusOK, resp) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) - } - } - - var imageData []byte - imageFilename := "" - if imageInput != "" { - decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) - if err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) - return nil, err - } - imageData = decoded - imageFilename = filename - } - - mediaID := "" - if len(imageData) > 0 { - uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - mediaID = uploadID - } - - taskID := "" - var err error - videoCount := parseSoraVideoCount(reqBody) - switch modelCfg.Type { - case "image": - taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ - Prompt: prompt, - Width: modelCfg.Width, - Height: modelCfg.Height, - MediaID: mediaID, - }) - case "video": - if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { - taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ - Prompt: formatSoraStoryboardPrompt(prompt), - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - MediaID: mediaID, - }) - } else { - taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ - Prompt: prompt, - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - VideoCount: videoCount, - MediaID: mediaID, - RemixTargetID: remixTargetID, - CameoIDs: extractSoraCameoIDs(reqBody), - }) - } - default: - err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) - } - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - - if clientStream && c != nil { - s.prepareSoraStream(c, taskID) - } - - var mediaURLs []string - videoGenerationID := "" - mediaType := modelCfg.Type - imageCount := 0 - imageSize := "" - switch modelCfg.Type { - case "image": - urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - mediaURLs = urls - imageCount = len(urls) - imageSize = soraImageSizeFromModel(reqModel) - case "video": - videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - if videoStatus != nil { - mediaURLs = videoStatus.URLs - videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) - } - default: - mediaType = "prompt" - } - - watermarkPostID := "" - if modelCfg.Type == "video" && watermarkOpts.Enabled { - watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) - if watermarkErr != nil { - if !watermarkOpts.FallbackOnFailure { - return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) - } - log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) - } else if strings.TrimSpace(watermarkURL) != "" { - mediaURLs = []string{strings.TrimSpace(watermarkURL)} - watermarkPostID = strings.TrimSpace(postID) - } - } - - // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 - // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 - finalURLs := s.normalizeSoraMediaURLs(mediaURLs) - if watermarkPostID != "" && watermarkOpts.DeletePost { - if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { - log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) - } - } - - content := buildSoraContent(mediaType, finalURLs) - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - response := buildSoraNonStreamResponse(content, reqModel) - if len(finalURLs) > 0 { - response["media_url"] = finalURLs[0] - if len(finalURLs) > 1 { - response["media_urls"] = finalURLs - } - } - c.JSON(http.StatusOK, response) - } - - return &ForwardResult{ - RequestID: taskID, - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} - -func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { - if s == nil || s.cfg == nil { - return ctx, nil - } - timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds - if stream { - timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds - } - if timeoutSeconds <= 0 { - return ctx, nil - } - return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) -} - -func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { - opts := soraWatermarkOptions{ - Enabled: parseBoolWithDefault(body, "watermark_free", false), - ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), - ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), - ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), - FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), - DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), - } - if opts.ParseMethod == "" { - opts.ParseMethod = "third_party" - } - return opts -} - -func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { - return soraCharacterOptions{ - SetPublic: parseBoolWithDefault(body, "character_set_public", true), - DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), - } -} - -func parseSoraVideoCount(body map[string]any) int { - if body == nil { - return 1 - } - keys := []string{"video_count", "videos", "n_variants"} - for _, key := range keys { - count := parseIntWithDefault(body, key, 0) - if count > 0 { - return clampInt(count, 1, 3) - } - } - return 1 -} - -func parseBoolWithDefault(body map[string]any, key string, def bool) bool { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case bool: - return typed - case int: - return typed != 0 - case int32: - return typed != 0 - case int64: - return typed != 0 - case float64: - return typed != 0 - case string: - typed = strings.ToLower(strings.TrimSpace(typed)) - if typed == "true" || typed == "1" || typed == "yes" { - return true - } - if typed == "false" || typed == "0" || typed == "no" { - return false - } - } - return def -} - -func parseStringWithDefault(body map[string]any, key, def string) string { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - if str, ok := val.(string); ok { - return str - } - return def -} - -func parseIntWithDefault(body map[string]any, key string, def int) int { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case int: - return typed - case int32: - return int(typed) - case int64: - return int(typed) - case float64: - return int(typed) - case string: - parsed, err := strconv.Atoi(strings.TrimSpace(typed)) - if err == nil { - return parsed - } - } - return def -} - -func clampInt(v, minVal, maxVal int) int { - if v < minVal { - return minVal - } - if v > maxVal { - return maxVal - } - return v -} - -func extractSoraCameoIDs(body map[string]any) []string { - if body == nil { - return nil - } - raw, ok := body["cameo_ids"] - if !ok { - return nil - } - switch typed := raw.(type) { - case []string: - out := make([]string, 0, len(typed)) - for _, item := range typed { - item = strings.TrimSpace(item) - if item != "" { - out = append(out, item) - } - } - return out - case []any: - out := make([]string, 0, len(typed)) - for _, item := range typed { - str, ok := item.(string) - if !ok { - continue - } - str = strings.TrimSpace(str) - if str != "" { - out = append(out, str) - } - } - return out - default: - return nil - } -} - -func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { - cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) - if err != nil { - return nil, err - } - - cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) - if err != nil { - return nil, err - } - username := processSoraCharacterUsername(cameoStatus.UsernameHint) - displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) - if displayName == "" { - displayName = "Character" - } - profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) - if profileAssetURL == "" { - return nil, errors.New("profile asset url not found in cameo status") - } - - avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) - if err != nil { - return nil, err - } - assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) - if err != nil { - return nil, err - } - instructionSet := cameoStatus.InstructionSetHint - if instructionSet == nil { - instructionSet = cameoStatus.InstructionSet - } - - characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ - CameoID: strings.TrimSpace(cameoID), - Username: username, - DisplayName: displayName, - ProfileAssetPointer: assetPointer, - InstructionSet: instructionSet, - }) - if err != nil { - return nil, err - } - - if opts.SetPublic { - if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { - return nil, err - } - } - - return &soraCharacterFlowResult{ - CameoID: strings.TrimSpace(cameoID), - CharacterID: strings.TrimSpace(characterID), - Username: strings.TrimSpace(username), - DisplayName: displayName, - }, nil -} - -func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - timeout := 10 * time.Minute - interval := 5 * time.Second - maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) - if maxAttempts < 1 { - maxAttempts = 1 - } - - var lastErr error - consecutiveErrors := 0 - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) - if err != nil { - lastErr = err - consecutiveErrors++ - if consecutiveErrors >= 3 { - break - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - consecutiveErrors = 0 - if status == nil { - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) - statusMessage := strings.TrimSpace(status.StatusMessage) - if currentStatus == "failed" { - if statusMessage == "" { - statusMessage = "character creation failed" - } - return nil, errors.New(statusMessage) - } - if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { - return status, nil - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - } - if lastErr != nil { - return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) - } - return nil, errors.New("cameo processing timeout") -} - -func processSoraCharacterUsername(usernameHint string) string { - usernameHint = strings.TrimSpace(usernameHint) - if usernameHint == "" { - usernameHint = "character" - } - if strings.Contains(usernameHint, ".") { - parts := strings.Split(usernameHint, ".") - usernameHint = strings.TrimSpace(parts[len(parts)-1]) - } - if usernameHint == "" { - usernameHint = "character" - } - return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) -} - -func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { - generationID = strings.TrimSpace(generationID) - if generationID == "" { - return "", "", errors.New("generation id is required for watermark-free mode") - } - postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) - if err != nil { - return "", "", err - } - postID = strings.TrimSpace(postID) - if postID == "" { - return "", "", errors.New("watermark-free publish returned empty post id") - } - - switch opts.ParseMethod { - case "custom": - urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) - if parseErr != nil { - return "", postID, parseErr - } - return strings.TrimSpace(urlVal), postID, nil - case "", "third_party": - return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil - default: - return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) - } -} - -func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { - switch statusCode { - case 401, 402, 403, 404, 429, 529: - return true - default: - return statusCode >= 500 - } -} - -func buildSoraNonStreamResponse(content, model string) map[string]any { - return map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "message": map[string]any{ - "role": "assistant", - "content": content, - }, - "finish_reason": "stop", - }, - }, - } -} - -func soraImageSizeFromModel(model string) string { - modelLower := strings.ToLower(model) - if size, ok := soraImageSizeMap[modelLower]; ok { - return size - } - if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { - return "540" - } - return "360" -} - -func soraProErrorMessage(model, upstreamMsg string) string { - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "sora2pro-hd") { - return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" - } - if strings.Contains(modelLower, "sora2pro") { - return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" - } - return "" -} - -func firstMediaURL(urls []string) string { - if len(urls) == 0 { - return "" - } - return urls[0] -} - -func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { - if path == "" { - return path - } - prefix := "/sora/media" - values := url.Values{} - if rawQuery != "" { - if parsed, err := url.ParseQuery(rawQuery); err == nil { - values = parsed - } - } - - signKey := "" - ttlSeconds := 0 - if s != nil && s.cfg != nil { - signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) - ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds - } - values.Del("sig") - values.Del("expires") - signingQuery := values.Encode() - if signKey != "" && ttlSeconds > 0 { - expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() - signature := SignSoraMediaURL(path, signingQuery, expires, signKey) - if signature != "" { - values.Set("expires", strconv.FormatInt(expires, 10)) - values.Set("sig", signature) - prefix = "/sora/media-signed" - } - } - - encoded := values.Encode() - if encoded == "" { - return prefix + path - } - return prefix + path + "?" + encoded -} - -func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { - if c == nil { - return - } - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - if strings.TrimSpace(requestID) != "" { - c.Header("x-request-id", requestID) - } -} - -func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { - if c == nil { - return nil, nil - } - writer := c.Writer - flusher, _ := writer.(http.Flusher) - - chunk := map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{ - "content": content, - }, - }, - }, - } - encoded, _ := jsonMarshalRaw(chunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { - return nil, err - } - if flusher != nil { - flusher.Flush() - } - ms := int(time.Since(startTime).Milliseconds()) - finalChunk := map[string]any{ - "id": chunk["id"], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{}, - "finish_reason": "stop", - }, - }, - } - finalEncoded, _ := jsonMarshalRaw(finalChunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { - return &ms, err - } - if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { - return &ms, err - } - if flusher != nil { - flusher.Flush() - } - return &ms, nil -} - -func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { - if c == nil { - return - } - if stream { - flusher, _ := c.Writer.(http.Flusher) - errorData := map[string]any{ - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) - _, _ = fmt.Fprint(c.Writer, errorEvent) - _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - return - } - c.JSON(status, gin.H{ - "error": gin.H{ - "type": errType, - "message": message, - }, - }) -} - -func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { - if err == nil { - return nil - } - var upstreamErr *SoraUpstreamError - if errors.As(err, &upstreamErr) { - accountID := int64(0) - if account != nil { - accountID = account.ID - } - logger.LegacyPrintf( - "service.sora", - "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", - accountID, - model, - upstreamErr.StatusCode, - strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), - strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), - strings.TrimSpace(upstreamErr.Message), - truncateForLog(upstreamErr.Body, 1024), - ) - if s.rateLimitService != nil && account != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) - } - if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - var responseHeaders http.Header - if upstreamErr.Headers != nil { - responseHeaders = upstreamErr.Headers.Clone() - } - return &UpstreamFailoverError{ - StatusCode: upstreamErr.StatusCode, - ResponseBody: upstreamErr.Body, - ResponseHeaders: responseHeaders, - } - } - msg := upstreamErr.Message - if override := soraProErrorMessage(model, msg); override != "" { - msg = override - } - s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) - return err - } - if errors.Is(err, context.DeadlineExceeded) { - s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) - return err - } - s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) - return err -} - -func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetImageTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "succeeded", "completed": - return status.URLs, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora image generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora image generation timeout") -} - -func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetVideoTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "completed", "succeeded": - return status, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora video generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora video generation timeout") -} - -func (s *SoraGatewayService) pollInterval() time.Duration { - if s == nil || s.cfg == nil { - return 2 * time.Second - } - interval := s.cfg.Sora.Client.PollIntervalSeconds - if interval <= 0 { - interval = 2 - } - return time.Duration(interval) * time.Second -} - -func (s *SoraGatewayService) pollMaxAttempts() int { - if s == nil || s.cfg == nil { - return 600 - } - maxAttempts := s.cfg.Sora.Client.MaxPollAttempts - if maxAttempts <= 0 { - maxAttempts = 600 - } - return maxAttempts -} - -func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { - if c == nil { - return - } - interval := 10 * time.Second - if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { - interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second - } - if time.Since(*lastPing) < interval { - return - } - if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - *lastPing = time.Now() - } -} - -func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { - if len(urls) == 0 { - return urls - } - output := make([]string, 0, len(urls)) - for _, raw := range urls { - raw = strings.TrimSpace(raw) - if raw == "" { - continue - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - output = append(output, raw) - continue - } - pathVal := raw - if !strings.HasPrefix(pathVal, "/") { - pathVal = "/" + pathVal - } - output = append(output, s.buildSoraMediaURL(pathVal, "")) - } - return output -} - -// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符, -// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。 -func jsonMarshalRaw(v any) ([]byte, error) { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(false) - if err := enc.Encode(v); err != nil { - return nil, err - } - // Encode 会追加换行符,去掉它 - b := buf.Bytes() - if len(b) > 0 && b[len(b)-1] == '\n' { - b = b[:len(b)-1] - } - return b, nil -} - -func buildSoraContent(mediaType string, urls []string) string { - switch mediaType { - case "image": - parts := make([]string, 0, len(urls)) - for _, u := range urls { - parts = append(parts, fmt.Sprintf("![image](%s)", u)) - } - return strings.Join(parts, "\n") - case "video": - if len(urls) == 0 { - return "" - } - return fmt.Sprintf("```html\n\n```", urls[0]) - default: - return "" - } -} - -func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) { - if body == nil { - return "", "", "", "" - } - if v, ok := body["remix_target_id"].(string); ok { - remixTargetID = strings.TrimSpace(v) - } - if v, ok := body["image"].(string); ok { - imageInput = v - } - if v, ok := body["video"].(string); ok { - videoInput = v - } - if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" { - prompt = v - } - if messages, ok := body["messages"].([]any); ok { - builder := strings.Builder{} - for _, raw := range messages { - msg, ok := raw.(map[string]any) - if !ok { - continue - } - role, _ := msg["role"].(string) - if role != "" && role != "user" { - continue - } - content := msg["content"] - text, img, vid := parseSoraMessageContent(content) - if text != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(text) - } - if imageInput == "" && img != "" { - imageInput = img - } - if videoInput == "" && vid != "" { - videoInput = vid - } - } - if prompt == "" { - prompt = builder.String() - } - } - if remixTargetID == "" { - remixTargetID = extractRemixTargetIDFromPrompt(prompt) - } - prompt = cleanRemixLinkFromPrompt(prompt) - return prompt, imageInput, videoInput, remixTargetID -} - -func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { - switch val := content.(type) { - case string: - return val, "", "" - case []any: - builder := strings.Builder{} - for _, item := range val { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := itemMap["type"].(string) - switch t { - case "text": - if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(txt) - } - case "image_url": - if imageInput == "" { - if urlVal, ok := itemMap["image_url"].(map[string]any); ok { - imageInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["image_url"].(string); ok { - imageInput = urlStr - } - } - case "video_url": - if videoInput == "" { - if urlVal, ok := itemMap["video_url"].(map[string]any); ok { - videoInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["video_url"].(string); ok { - videoInput = urlStr - } - } - } - } - return builder.String(), imageInput, videoInput - default: - return "", "", "" - } -} - -func isSoraStoryboardPrompt(prompt string) bool { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return false - } - return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 -} - -func formatSoraStoryboardPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) - if len(matches) == 0 { - return prompt - } - firstBracketPos := strings.Index(prompt, "[") - instructions := "" - if firstBracketPos > 0 { - instructions = strings.TrimSpace(prompt[:firstBracketPos]) - } - shots := make([]string, 0, len(matches)) - for i, match := range matches { - if len(match) < 3 { - continue - } - duration := strings.TrimSpace(match[1]) - scene := strings.TrimSpace(match[2]) - if scene == "" { - continue - } - shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) - } - if len(shots) == 0 { - return prompt - } - timeline := strings.Join(shots, "\n\n") - if instructions == "" { - return timeline - } - return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) -} - -func extractRemixTargetIDFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) -} - -func cleanRemixLinkFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return prompt - } - cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") - cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") - cleaned = strings.Join(strings.Fields(cleaned), " ") - return strings.TrimSpace(cleaned) -} - -func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, "", errors.New("empty image input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, "", errors.New("invalid data url") - } - meta := parts[0] - payload := parts[1] - decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) - if err != nil { - return nil, "", err - } - ext := "" - if strings.HasPrefix(meta, "data:") { - metaParts := strings.SplitN(meta[5:], ";", 2) - if len(metaParts) > 0 { - if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { - ext = exts[0] - } - } - } - filename := "image" + ext - return decoded, filename, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraImageInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) - if err != nil { - return nil, "", errors.New("invalid base64 image") - } - return decoded, "image.png", nil -} - -func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, errors.New("empty video input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, errors.New("invalid video data url") - } - decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraVideoInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil -} - -func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, "", err - } - client := &http.Client{ - Timeout: soraImageInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraImageInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, "", err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) - if err != nil { - return nil, "", err - } - ext := fileExtFromURL(parsed.String()) - if ext == "" { - ext = fileExtFromContentType(resp.Header.Get("Content-Type")) - } - filename := "image" + ext - return data, filename, nil -} - -func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, err - } - client := &http.Client{ - Timeout: soraVideoInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraVideoInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, errors.New("empty video content") - } - return data, nil -} - -func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { - if maxBytes <= 0 { - return nil, errors.New("invalid max bytes limit") - } - decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) - limited := io.LimitReader(decoder, maxBytes+1) - data, err := io.ReadAll(limited) - if err != nil { - return nil, err - } - if int64(len(data)) > maxBytes { - return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) - } - return data, nil -} - -func validateSoraRemoteURL(raw string) (*url.URL, error) { - if strings.TrimSpace(raw) == "" { - return nil, errors.New("empty remote url") - } - parsed, err := url.Parse(raw) - if err != nil { - return nil, fmt.Errorf("invalid remote url: %w", err) - } - if err := validateSoraRemoteURLValue(parsed); err != nil { - return nil, err - } - return parsed, nil -} - -func validateSoraRemoteURLValue(parsed *url.URL) error { - if parsed == nil { - return errors.New("invalid remote url") - } - scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) - if scheme != "http" && scheme != "https" { - return errors.New("only http/https remote url is allowed") - } - if parsed.User != nil { - return errors.New("remote url cannot contain userinfo") - } - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - if host == "" { - return errors.New("remote url missing host") - } - if _, blocked := soraBlockedHostnames[host]; blocked { - return errors.New("remote url is not allowed") - } - if ip := net.ParseIP(host); ip != nil { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - return nil - } - ips, err := net.LookupIP(host) - if err != nil { - return fmt.Errorf("resolve remote url failed: %w", err) - } - for _, ip := range ips { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - } - return nil -} - -func isSoraBlockedIP(ip net.IP) bool { - if ip == nil { - return true - } - for _, cidr := range soraBlockedCIDRs { - if cidr.Contains(ip) { - return true - } - } - return false -} - -func mustParseCIDRs(values []string) []*net.IPNet { - out := make([]*net.IPNet, 0, len(values)) - for _, val := range values { - _, cidr, err := net.ParseCIDR(val) - if err != nil { - continue - } - out = append(out, cidr) - } - return out -} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go deleted file mode 100644 index 2fef600c..00000000 --- a/backend/internal/service/sora_gateway_service_test.go +++ /dev/null @@ -1,564 +0,0 @@ -//go:build unit - -package service - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -var _ SoraClient = (*stubSoraClientForPoll)(nil) - -type stubSoraClientForPoll struct { - imageStatus *SoraImageTaskStatus - videoStatus *SoraVideoTaskStatus - imageCalls int - videoCalls int - enhanced string - enhanceErr error - storyboard bool - videoReq SoraVideoRequest - parseErr error - postCalls int - deleteCalls int -} - -func (s *stubSoraClientForPoll) Enabled() bool { return true } -func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { - return "", nil -} -func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { - s.videoReq = req - return "task-video", nil -} -func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { - s.storyboard = true - return "task-video", nil -} -func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { - return "cameo-1", nil -} -func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - return &SoraCameoStatus{ - Status: "finalized", - StatusMessage: "Completed", - DisplayNameHint: "Character", - UsernameHint: "user.character", - ProfileAssetURL: "https://example.com/avatar.webp", - }, nil -} -func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { - return []byte("avatar"), nil -} -func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { - return "asset-pointer", nil -} -func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { - return "character-1", nil -} -func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { - return nil -} -func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { - return nil -} -func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { - s.postCalls++ - return "s_post", nil -} -func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { - s.deleteCalls++ - return nil -} -func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { - if s.parseErr != nil { - return "", s.parseErr - } - return "https://example.com/no-watermark.mp4", nil -} -func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { - if s.enhanced != "" { - return s.enhanced, s.enhanceErr - } - return "enhanced prompt", s.enhanceErr -} -func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { - s.imageCalls++ - return s.imageStatus, nil -} -func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { - s.videoCalls++ - return s.videoStatus, nil -} - -func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { - client := &stubSoraClientForPoll{ - imageStatus: &SoraImageTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/a.png"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) - require.NoError(t, err) - require.Equal(t, []string{"https://example.com/a.png"}, urls) - require.Equal(t, 1, client.imageCalls) -} - -func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { - client := &stubSoraClientForPoll{ - enhanced: "cinematic prompt", - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ - ID: 1, - Platform: PlatformSora, - Status: StatusActive, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "prompt-enhance-short-10s": "prompt-enhance-short-15s", - }, - }, - } - body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, "prompt-enhance-short-10s", result.Model) - require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) -} - -func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.True(t, client.storyboard) -} - -func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 3, client.videoReq.VideoCount) -} - -func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { - client := &stubSoraClientForPoll{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, 0, client.videoCalls) -} - -func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - parseErr: errors.New("parse failed"), - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/original.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 0, client.deleteCalls) -} - -func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 1, client.deleteCalls) -} - -func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "reject", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) - require.Error(t, err) - require.Nil(t, status) - require.Contains(t, err.Error(), "reject") - require.Equal(t, 1, client.videoCalls) -} - -func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { - cfg := &config.Config{ - Gateway: config.GatewayConfig{ - SoraMediaSigningKey: "test-key", - SoraMediaSignedURLTTLSeconds: 600, - }, - } - service := NewSoraGatewayService(nil, nil, nil, cfg) - - url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") - require.Contains(t, url, "/sora/media-signed") - require.Contains(t, url, "expires=") - require.Contains(t, url, "sig=") -} - -func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - result := svc.normalizeSoraMediaURLs(nil) - require.Empty(t, result) - - result = svc.normalizeSoraMediaURLs([]string{}) - require.Empty(t, result) -} - -func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Equal(t, urls, result) -} - -func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { - cfg := &config.Config{} - svc := NewSoraGatewayService(nil, nil, nil, cfg) - urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) - require.Contains(t, result[0], "/sora/media") - require.Contains(t, result[1], "/sora/media") -} - -func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) -} - -func TestBuildSoraContent_Image(t *testing.T) { - content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) - require.Contains(t, content, "![image](https://a.com/1.png)") - require.Contains(t, content, "![image](https://a.com/2.png)") -} - -func TestBuildSoraContent_Video(t *testing.T) { - content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) - require.Contains(t, content, "