diff --git a/README.md b/README.md index 5415ea61..cb9d36ed 100644 --- a/README.md +++ b/README.md @@ -103,9 +103,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot pateway -Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line. -Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information. -Register now via this link to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150. +Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium API relay built for heavy AI developers, offering the full Claude and Codex series sourced 100% from official providers, with transparent token-level billing. Enterprise plans include high concurrency, dedicated management, contracts, and invoicing. Register now to get $3 in trial credits, top-ups from 60% off, and referral bonuses up to $150. @@ -120,6 +118,18 @@ Register now via this link to recei + +unity2 +Thanks to Unity2 for sponsoring this project! Unity2 is a high-performance AI model API relay for individuals, teams, and enterprises, handling 30B+ tokens/day with 5000 RPM concurrency. One API Key works across Claude Code, Codex, OpenAI models, IDE plugins, and Agent workflows, with balance billing, bundled subscriptions, enterprise invoicing, and 1-on-1 support. Register to claim $2 in balance, plus $10 more by joining the official group — up to $12 in free credit. + + + + +veilx +Thanks to Veilx for sponsoring this project! Veilx CDN is purpose-built for large-scale AI API traffic, deeply optimized for relay services and call chains across OpenAI, Claude, Gemini, and scenarios like chat, image generation, embeddings, and streaming — delivering lower latency and higher stability under heavy concurrency. It also offers China three-network optimized return lines, making it ideal for global AI relay platforms, overseas AI SaaS, and cross-border high-concurrency deployments. + + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index ca7b3218..6c9ff372 100644 --- a/README_CN.md +++ b/README_CN.md @@ -119,6 +119,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 + +unity2 +感谢 Unity2 赞助本项目! Unity2 是面向个人开发者、团队、企业的高性能 AI 模型 API 中转平台,长期服务国内头部企业,日均承载超 300 亿 token 调用,支持 5000 RPM 级高并发。一个 API Key 即可适配 Claude Code、Codex、OpenAI 模型、IDE 插件和 Agent 工作流等场景。具备企业级稳定供应能力,在高并发、持续调用和团队集中采购场景下依然保持低延迟、高可用。同时支持余额计费、组合订阅、首充优惠、企业开票、专属 1v1 对接,适合个人高频使用和企业长期接入。现在注册 Unity2.ai 可领取 $2 余额,加入官方群再送 $10 余额,合计最高可领 $12 免费额度,适合先体验后长期使用。注册链接 + + + + +veilx +感谢 Veilx 赞助本项目! Veilx CDN 专为超大规模 API 请求场景打造,针对 AI 中转站业务与 AI API 调用链路进行了深度优化,轻松应对高并发、高频请求与大流量传输,为开发者与企业提供更快、更稳、更低延迟的加速体验。无论是 OpenAI、Claude、Gemini 等 AI 接口中转,还是聊天、绘图、Embedding、流式输出等复杂场景,Veilx 都能显著提升响应速度与连接稳定性,有效降低网络波动带来的超时与失败问题。同时,Veilx 提供中国三网优化回国极速线路,大幅提升中国大陆地区访问海外 AI 服务的速度与稳定性,特别适合全球 AI 中转平台、海外 AI SaaS、跨境业务与高并发 API 系统部署。专为 AI API 而生,让你的 AI 中转服务更快、更稳、更省心。购买地址 + + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index 45adfd65..e3453d5e 100644 --- a/README_JA.md +++ b/README_JA.md @@ -119,6 +119,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを + +unity2 +Unity2 のご支援に感謝します!Unity2 は個人開発者、チーム、企業向けの高性能 AI モデル API 中継プラットフォームです。中国の大手企業に長期にわたりサービスを提供しており、1日あたり 300 億以上のトークン呼び出しを処理し、5000 RPM 級の高並列性をサポートします。1つの API キーで Claude Code、Codex、OpenAI モデル、IDE プラグイン、Agent ワークフローなど様々なシナリオに対応できます。エンタープライズグレードの安定供給能力を備え、高並列・継続的な呼び出し・チームの集中購入シーンでも低レイテンシと高可用性を維持します。残高課金、組み合わせサブスクリプション、初回チャージ特典、企業向け請求書発行、専属 1v1 サポートにも対応しており、個人の頻繁な利用にも企業の長期導入にも適しています。今 Unity2.ai に登録すると $2 の残高、公式グループに参加するとさらに $10 の残高がもらえ、合計最大 $12 の無料クレジットを獲得できます — 試用後に長期利用したい方に最適です。登録リンク + + + + +veilx +Veilx のご支援に感謝します!Veilx CDN は超大規模 API リクエストシナリオ向けに設計されており、AI 中継サービスと AI API 呼び出しチェーンに対して深く最適化されています。高並列・高頻度リクエスト・大容量トラフィックに容易に対応し、開発者と企業により高速で安定した、低レイテンシの加速体験を提供します。OpenAI、Claude、Gemini などの AI インターフェース中継はもちろん、チャット、画像生成、Embedding、ストリーミング出力などの複雑なシナリオでも、Veilx は応答速度と接続安定性を大幅に向上させ、ネットワーク変動によるタイムアウトや失敗を効果的に削減します。さらに、Veilx は中国三大ネットワーク最適化の高速回線を提供しており、中国本土から海外 AI サービスへのアクセス速度と安定性を大幅に向上させます。グローバル AI 中継プラットフォーム、海外 AI SaaS、越境ビジネス、高並列 API システム展開に特に適しています。AI API のために生まれ、あなたの AI 中継サービスをより速く、より安定して、より安心に。購入リンク + + + ## エコシステム diff --git a/assets/partners/logos/unity2.png b/assets/partners/logos/unity2.png new file mode 100644 index 00000000..f1da2ed1 Binary files /dev/null and b/assets/partners/logos/unity2.png differ diff --git a/assets/partners/logos/veilx.png b/assets/partners/logos/veilx.png new file mode 100644 index 00000000..33a37883 Binary files /dev/null and b/assets/partners/logos/veilx.png differ diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 9386678d..fc8a00c4 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index d0c4abc6..7b9dfc4d 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.130 +0.1.132 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 76fca0df..6c93de50 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyRepository := repository.NewAPIKeyRepository(client, db) userRPMCache := repository.NewUserRPMCache(redisClient) userGroupRateRepository := repository.NewUserGroupRateRepository(db) - billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) + userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client) + serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository) + billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) @@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) affiliateRepository := repository.NewAffiliateRepository(client, db) affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService) - authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService) @@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService) - userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService) - adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) rpmCache := repository.NewRPMCache(redisClient) groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) @@ -194,7 +196,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) digestSessionStore := service.NewDigestSessionStore() rpmTokenBucketService := service.NewRPMTokenBucketService() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 3e2bd7c5..de9d789a 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) diff --git a/backend/ent/client.go b/backend/ent/client.go index df20ddfa..06b53c78 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -48,6 +48,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" stdsql "database/sql" @@ -124,6 +125,8 @@ type Client struct { UserAttributeDefinition *UserAttributeDefinitionClient // UserAttributeValue is the client for interacting with the UserAttributeValue builders. UserAttributeValue *UserAttributeValueClient + // UserPlatformQuota is the client for interacting with the UserPlatformQuota builders. + UserPlatformQuota *UserPlatformQuotaClient // UserSubscription is the client for interacting with the UserSubscription builders. UserSubscription *UserSubscriptionClient } @@ -170,6 +173,7 @@ func (c *Client) init() { c.UserAllowedGroup = NewUserAllowedGroupClient(c.config) c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config) c.UserAttributeValue = NewUserAttributeValueClient(c.config) + c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config) c.UserSubscription = NewUserSubscriptionClient(c.config) } @@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), UserAttributeValue: NewUserAttributeValueClient(cfg), + UserPlatformQuota: NewUserPlatformQuotaClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), UserAttributeValue: NewUserAttributeValueClient(cfg), + UserPlatformQuota: NewUserPlatformQuotaClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) { c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.UserPlatformQuota, c.UserSubscription, } { n.Use(hooks...) } @@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) { c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.UserPlatformQuota, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.UserAttributeDefinition.mutate(ctx, m) case *UserAttributeValueMutation: return c.UserAttributeValue.mutate(ctx, m) + case *UserPlatformQuotaMutation: + return c.UserPlatformQuota.mutate(ctx, m) case *UserSubscriptionMutation: return c.UserSubscription.mutate(ctx, m) default: @@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery return query } +// QueryPlatformQuotas queries the platform_quotas edge of a User. +func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery { + query := (&UserPlatformQuotaClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV } } +// UserPlatformQuotaClient is a client for the UserPlatformQuota schema. +type UserPlatformQuotaClient struct { + config +} + +// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config. +func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient { + return &UserPlatformQuotaClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`. +func (c *UserPlatformQuotaClient) Use(hooks ...Hook) { + c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`. +func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) { + c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...) +} + +// Create returns a builder for creating a UserPlatformQuota entity. +func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate { + mutation := newUserPlatformQuotaMutation(c.config, OpCreate) + return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities. +func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk { + return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UserPlatformQuotaCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdate) + return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m)) + return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id)) + return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete { + mutation := newUserPlatformQuotaMutation(c.config, OpDelete) + return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne { + builder := c.Delete().Where(userplatformquota.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UserPlatformQuotaDeleteOne{builder} +} + +// Query returns a query builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery { + return &UserPlatformQuotaQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUserPlatformQuota}, + inters: c.Interceptors(), + } +} + +// Get returns a UserPlatformQuota entity by its id. +func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) { + return c.Query().Where(userplatformquota.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a UserPlatformQuota. +func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *UserPlatformQuotaClient) Hooks() []Hook { + hooks := c.hooks.UserPlatformQuota + return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *UserPlatformQuotaClient) Interceptors() []Interceptor { + inters := c.inters.UserPlatformQuota + return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...) +} + +func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op()) + } +} + // UserSubscriptionClient is a client for the UserSubscription schema. type UserSubscriptionClient struct { config @@ -6025,7 +6200,8 @@ type ( PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + UserAttributeDefinition, UserAttributeValue, UserPlatformQuota, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, @@ -6035,7 +6211,8 @@ type ( PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + UserAttributeDefinition, UserAttributeValue, UserPlatformQuota, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index c9fcc314..33d36e70 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -45,6 +45,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -139,6 +140,7 @@ func checkColumn(t, c string) error { userallowedgroup.Table: userallowedgroup.ValidColumn, userattributedefinition.Table: userattributedefinition.ValidColumn, userattributevalue.Table: userattributevalue.ValidColumn, + userplatformquota.Table: userplatformquota.ValidColumn, usersubscription.Table: usersubscription.ValidColumn, }) }) diff --git a/backend/ent/group.go b/backend/ent/group.go index a4f52c73..298df88a 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -85,6 +85,8 @@ type Group struct { DefaultMappedModel string `json:"default_mapped_model,omitempty"` // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + // 自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度 + ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config,omitempty"` // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流 RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -193,7 +195,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: + case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig, group.FieldModelsListConfig: values[i] = new([]byte) case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) @@ -440,6 +442,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) } } + case group.FieldModelsListConfig: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field models_list_config", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ModelsListConfig); err != nil { + return fmt.Errorf("unmarshal field models_list_config: %w", err) + } + } case group.FieldRpmLimit: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) @@ -641,6 +651,9 @@ func (_m *Group) String() string { builder.WriteString("messages_dispatch_model_config=") builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) builder.WriteString(", ") + builder.WriteString("models_list_config=") + builder.WriteString(fmt.Sprintf("%v", _m.ModelsListConfig)) + builder.WriteString(", ") builder.WriteString("rpm_limit=") builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 4e9ba6b6..ebe9bd7e 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -82,6 +82,8 @@ const ( FieldDefaultMappedModel = "default_mapped_model" // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" + // FieldModelsListConfig holds the string denoting the models_list_config field in the database. + FieldModelsListConfig = "models_list_config" // FieldRpmLimit holds the string denoting the rpm_limit field in the database. FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. @@ -192,6 +194,7 @@ var Columns = []string{ FieldRequirePrivacySet, FieldDefaultMappedModel, FieldMessagesDispatchModelConfig, + FieldModelsListConfig, FieldRpmLimit, } @@ -276,6 +279,8 @@ var ( DefaultMappedModelValidator func(string) error // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig + // DefaultModelsListConfig holds the default value on creation for the "models_list_config" field. + DefaultModelsListConfig domain.GroupModelsListConfig // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. DefaultRpmLimit int ) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 44b905bd..d5ed0c19 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -467,6 +467,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _c } +// SetModelsListConfig sets the "models_list_config" field. +func (_c *GroupCreate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupCreate { + _c.mutation.SetModelsListConfig(v) + return _c +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_c *GroupCreate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupCreate { + if v != nil { + _c.SetModelsListConfig(*v) + } + return _c +} + // SetRpmLimit sets the "rpm_limit" field. func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate { _c.mutation.SetRpmLimit(v) @@ -698,6 +712,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultMessagesDispatchModelConfig _c.mutation.SetMessagesDispatchModelConfig(v) } + if _, ok := _c.mutation.ModelsListConfig(); !ok { + v := group.DefaultModelsListConfig + _c.mutation.SetModelsListConfig(v) + } if _, ok := _c.mutation.RpmLimit(); !ok { v := group.DefaultRpmLimit _c.mutation.SetRpmLimit(v) @@ -798,6 +816,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} } + if _, ok := _c.mutation.ModelsListConfig(); !ok { + return &ValidationError{Name: "models_list_config", err: errors.New(`ent: missing required field "Group.models_list_config"`)} + } if _, ok := _c.mutation.RpmLimit(); !ok { return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)} } @@ -960,6 +981,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) _node.MessagesDispatchModelConfig = value } + if value, ok := _c.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + _node.ModelsListConfig = value + } if value, ok := _c.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) _node.RpmLimit = value @@ -1642,6 +1667,18 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { return u } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsert) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsert { + u.Set(group.FieldModelsListConfig, v) + return u +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsert) UpdateModelsListConfig() *GroupUpsert { + u.SetExcluded(group.FieldModelsListConfig) + return u +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert { u.Set(group.FieldRpmLimit, v) @@ -2314,6 +2351,20 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { }) } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsertOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetModelsListConfig(v) + }) +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateModelsListConfig() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelsListConfig() + }) +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -3155,6 +3206,20 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { }) } +// SetModelsListConfig sets the "models_list_config" field. +func (u *GroupUpsertBulk) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetModelsListConfig(v) + }) +} + +// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateModelsListConfig() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelsListConfig() + }) +} + // SetRpmLimit sets the "rpm_limit" field. func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index fe55982c..c10d60ec 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -616,6 +616,20 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _u } +// SetModelsListConfig sets the "models_list_config" field. +func (_u *GroupUpdate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdate { + _u.mutation.SetModelsListConfig(v) + return _u +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdate { + if v != nil { + _u.SetModelsListConfig(*v) + } + return _u +} + // SetRpmLimit sets the "rpm_limit" field. func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate { _u.mutation.ResetRpmLimit() @@ -1112,6 +1126,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + } if value, ok := _u.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) } @@ -2012,6 +2029,20 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA return _u } +// SetModelsListConfig sets the "models_list_config" field. +func (_u *GroupUpdateOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdateOne { + _u.mutation.SetModelsListConfig(v) + return _u +} + +// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdateOne { + if v != nil { + _u.SetModelsListConfig(*v) + } + return _u +} + // SetRpmLimit sets the "rpm_limit" field. func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne { _u.mutation.ResetRpmLimit() @@ -2538,6 +2569,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.ModelsListConfig(); ok { + _spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value) + } if value, ok := _u.mutation.RpmLimit(); ok { _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 414eba24..71bfd3b8 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m) } +// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary +// function as UserPlatformQuota mutator. +type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m) +} + // The UserSubscriptionFunc type is an adapter to allow the use of ordinary // function as UserSubscription mutator. type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 95b68e09..5d86e25b 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -42,6 +42,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q) } +// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier. +type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.UserPlatformQuotaQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q) +} + +// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser. +type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.UserPlatformQuotaQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q) +} + // The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier. type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error) @@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil case *ent.UserAttributeValueQuery: return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil + case *ent.UserPlatformQuotaQuery: + return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil case *ent.UserSubscriptionQuery: return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil default: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index b1731a35..7abe4c60 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -669,6 +669,7 @@ var ( {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "models_list_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. @@ -1612,6 +1613,53 @@ var ( }, }, } + // UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table. + UserPlatformQuotasColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "platform", Type: field.TypeString, Size: 32}, + {Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table. + UserPlatformQuotasTable = &schema.Table{ + Name: "user_platform_quotas", + Columns: UserPlatformQuotasColumns, + PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "user_platform_quotas_users_platform_quotas", + Columns: []*schema.Column{UserPlatformQuotasColumns[14]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "userplatformquota_user_id_platform", + Unique: true, + Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]}, + Annotation: &entsql.IndexAnnotation{ + Where: "deleted_at IS NULL", + }, + }, + { + Name: "userplatformquota_user_id", + Unique: false, + Columns: []*schema.Column{UserPlatformQuotasColumns[14]}, + }, + }, + } // UserSubscriptionsColumns holds the columns for the "user_subscriptions" table. UserSubscriptionsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1736,6 +1784,7 @@ var ( UserAllowedGroupsTable, UserAttributeDefinitionsTable, UserAttributeValuesTable, + UserPlatformQuotasTable, UserSubscriptionsTable, } ) @@ -1869,6 +1918,10 @@ func init() { UserAttributeValuesTable.Annotation = &entsql.Annotation{ Table: "user_attribute_values", } + UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable + UserPlatformQuotasTable.Annotation = &entsql.Annotation{ + Table: "user_platform_quotas", + } UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index af0edc68..003e25d5 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -46,6 +46,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -92,6 +93,7 @@ const ( TypeUserAllowedGroup = "UserAllowedGroup" TypeUserAttributeDefinition = "UserAttributeDefinition" TypeUserAttributeValue = "UserAttributeValue" + TypeUserPlatformQuota = "UserPlatformQuota" TypeUserSubscription = "UserSubscription" ) @@ -14899,6 +14901,7 @@ type GroupMutation struct { require_privacy_set *bool default_mapped_model *string messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + models_list_config *domain.GroupModelsListConfig rpm_limit *int addrpm_limit *int clearedFields map[string]struct{} @@ -16617,6 +16620,42 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() { m.messages_dispatch_model_config = nil } +// SetModelsListConfig sets the "models_list_config" field. +func (m *GroupMutation) SetModelsListConfig(dmlc domain.GroupModelsListConfig) { + m.models_list_config = &dmlc +} + +// ModelsListConfig returns the value of the "models_list_config" field in the mutation. +func (m *GroupMutation) ModelsListConfig() (r domain.GroupModelsListConfig, exists bool) { + v := m.models_list_config + if v == nil { + return + } + return *v, true +} + +// OldModelsListConfig returns the old "models_list_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelsListConfig(ctx context.Context) (v domain.GroupModelsListConfig, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelsListConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelsListConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelsListConfig: %w", err) + } + return oldValue.ModelsListConfig, nil +} + +// ResetModelsListConfig resets all changes to the "models_list_config" field. +func (m *GroupMutation) ResetModelsListConfig() { + m.models_list_config = nil +} + // SetRpmLimit sets the "rpm_limit" field. func (m *GroupMutation) SetRpmLimit(i int) { m.rpm_limit = &i @@ -17031,7 +17070,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 34) + fields := make([]string, 0, 35) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -17131,6 +17170,9 @@ func (m *GroupMutation) Fields() []string { if m.messages_dispatch_model_config != nil { fields = append(fields, group.FieldMessagesDispatchModelConfig) } + if m.models_list_config != nil { + fields = append(fields, group.FieldModelsListConfig) + } if m.rpm_limit != nil { fields = append(fields, group.FieldRpmLimit) } @@ -17208,6 +17250,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.DefaultMappedModel() case group.FieldMessagesDispatchModelConfig: return m.MessagesDispatchModelConfig() + case group.FieldModelsListConfig: + return m.ModelsListConfig() case group.FieldRpmLimit: return m.RpmLimit() } @@ -17285,6 +17329,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldDefaultMappedModel(ctx) case group.FieldMessagesDispatchModelConfig: return m.OldMessagesDispatchModelConfig(ctx) + case group.FieldModelsListConfig: + return m.OldModelsListConfig(ctx) case group.FieldRpmLimit: return m.OldRpmLimit(ctx) } @@ -17527,6 +17573,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetMessagesDispatchModelConfig(v) return nil + case group.FieldModelsListConfig: + v, ok := value.(domain.GroupModelsListConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelsListConfig(v) + return nil case group.FieldRpmLimit: v, ok := value.(int) if !ok { @@ -17910,6 +17963,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMessagesDispatchModelConfig: m.ResetMessagesDispatchModelConfig() return nil + case group.FieldModelsListConfig: + m.ResetModelsListConfig() + return nil case group.FieldRpmLimit: m.ResetRpmLimit() return nil @@ -38160,6 +38216,9 @@ type UserMutation struct { pending_auth_sessions map[int64]struct{} removedpending_auth_sessions map[int64]struct{} clearedpending_auth_sessions bool + platform_quotas map[int64]struct{} + removedplatform_quotas map[int64]struct{} + clearedplatform_quotas bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -39918,6 +39977,60 @@ func (m *UserMutation) ResetPendingAuthSessions() { m.removedpending_auth_sessions = nil } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by ids. +func (m *UserMutation) AddPlatformQuotaIDs(ids ...int64) { + if m.platform_quotas == nil { + m.platform_quotas = make(map[int64]struct{}) + } + for i := range ids { + m.platform_quotas[ids[i]] = struct{}{} + } +} + +// ClearPlatformQuotas clears the "platform_quotas" edge to the UserPlatformQuota entity. +func (m *UserMutation) ClearPlatformQuotas() { + m.clearedplatform_quotas = true +} + +// PlatformQuotasCleared reports if the "platform_quotas" edge to the UserPlatformQuota entity was cleared. +func (m *UserMutation) PlatformQuotasCleared() bool { + return m.clearedplatform_quotas +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (m *UserMutation) RemovePlatformQuotaIDs(ids ...int64) { + if m.removedplatform_quotas == nil { + m.removedplatform_quotas = make(map[int64]struct{}) + } + for i := range ids { + delete(m.platform_quotas, ids[i]) + m.removedplatform_quotas[ids[i]] = struct{}{} + } +} + +// RemovedPlatformQuotas returns the removed IDs of the "platform_quotas" edge to the UserPlatformQuota entity. +func (m *UserMutation) RemovedPlatformQuotasIDs() (ids []int64) { + for id := range m.removedplatform_quotas { + ids = append(ids, id) + } + return +} + +// PlatformQuotasIDs returns the "platform_quotas" edge IDs in the mutation. +func (m *UserMutation) PlatformQuotasIDs() (ids []int64) { + for id := range m.platform_quotas { + ids = append(ids, id) + } + return +} + +// ResetPlatformQuotas resets all changes to the "platform_quotas" edge. +func (m *UserMutation) ResetPlatformQuotas() { + m.platform_quotas = nil + m.clearedplatform_quotas = false + m.removedplatform_quotas = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -40527,7 +40640,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -40564,6 +40677,9 @@ func (m *UserMutation) AddedEdges() []string { if m.pending_auth_sessions != nil { edges = append(edges, user.EdgePendingAuthSessions) } + if m.platform_quotas != nil { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40643,13 +40759,19 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgePlatformQuotas: + ids := make([]ent.Value, 0, len(m.platform_quotas)) + for id := range m.platform_quotas { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -40686,6 +40808,9 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpending_auth_sessions != nil { edges = append(edges, user.EdgePendingAuthSessions) } + if m.removedplatform_quotas != nil { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40765,13 +40890,19 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgePlatformQuotas: + ids := make([]ent.Value, 0, len(m.removedplatform_quotas)) + for id := range m.removedplatform_quotas { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -40808,6 +40939,9 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpending_auth_sessions { edges = append(edges, user.EdgePendingAuthSessions) } + if m.clearedplatform_quotas { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40839,6 +40973,8 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedauth_identities case user.EdgePendingAuthSessions: return m.clearedpending_auth_sessions + case user.EdgePlatformQuotas: + return m.clearedplatform_quotas } return false } @@ -40891,6 +41027,9 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePendingAuthSessions: m.ResetPendingAuthSessions() return nil + case user.EdgePlatformQuotas: + m.ResetPlatformQuotas() + return nil } return fmt.Errorf("unknown User edge %s", name) } @@ -43111,6 +43250,1428 @@ func (m *UserAttributeValueMutation) ResetEdge(name string) error { return fmt.Errorf("unknown UserAttributeValue edge %s", name) } +// UserPlatformQuotaMutation represents an operation that mutates the UserPlatformQuota nodes in the graph. +type UserPlatformQuotaMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + platform *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + daily_usage_usd *float64 + adddaily_usage_usd *float64 + weekly_usage_usd *float64 + addweekly_usage_usd *float64 + monthly_usage_usd *float64 + addmonthly_usage_usd *float64 + daily_window_start *time.Time + weekly_window_start *time.Time + monthly_window_start *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + done bool + oldValue func(context.Context) (*UserPlatformQuota, error) + predicates []predicate.UserPlatformQuota +} + +var _ ent.Mutation = (*UserPlatformQuotaMutation)(nil) + +// userplatformquotaOption allows management of the mutation configuration using functional options. +type userplatformquotaOption func(*UserPlatformQuotaMutation) + +// newUserPlatformQuotaMutation creates new mutation for the UserPlatformQuota entity. +func newUserPlatformQuotaMutation(c config, op Op, opts ...userplatformquotaOption) *UserPlatformQuotaMutation { + m := &UserPlatformQuotaMutation{ + config: c, + op: op, + typ: TypeUserPlatformQuota, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUserPlatformQuotaID sets the ID field of the mutation. +func withUserPlatformQuotaID(id int64) userplatformquotaOption { + return func(m *UserPlatformQuotaMutation) { + var ( + err error + once sync.Once + value *UserPlatformQuota + ) + m.oldValue = func(ctx context.Context) (*UserPlatformQuota, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().UserPlatformQuota.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUserPlatformQuota sets the old UserPlatformQuota of the mutation. +func withUserPlatformQuota(node *UserPlatformQuota) userplatformquotaOption { + return func(m *UserPlatformQuotaMutation) { + m.oldValue = func(context.Context) (*UserPlatformQuota, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UserPlatformQuotaMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UserPlatformQuotaMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UserPlatformQuotaMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *UserPlatformQuotaMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().UserPlatformQuota.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *UserPlatformQuotaMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UserPlatformQuotaMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UserPlatformQuotaMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *UserPlatformQuotaMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *UserPlatformQuotaMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *UserPlatformQuotaMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *UserPlatformQuotaMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserPlatformQuotaMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserPlatformQuotaMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[userplatformquota.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserPlatformQuotaMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, userplatformquota.FieldDeletedAt) +} + +// SetUserID sets the "user_id" field. +func (m *UserPlatformQuotaMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *UserPlatformQuotaMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *UserPlatformQuotaMutation) ResetUserID() { + m.user = nil +} + +// SetPlatform sets the "platform" field. +func (m *UserPlatformQuotaMutation) SetPlatform(s string) { + m.platform = &s +} + +// Platform returns the value of the "platform" field in the mutation. +func (m *UserPlatformQuotaMutation) Platform() (r string, exists bool) { + v := m.platform + if v == nil { + return + } + return *v, true +} + +// OldPlatform returns the old "platform" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldPlatform(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatform is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatform requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + } + return oldValue.Platform, nil +} + +// ResetPlatform resets all changes to the "platform" field. +func (m *UserPlatformQuotaMutation) ResetPlatform() { + m.platform = nil +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetDailyLimitUsd(f float64) { + m.daily_limit_usd = &f + m.adddaily_limit_usd = nil +} + +// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyLimitUsd() (r float64, exists bool) { + v := m.daily_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + } + return oldValue.DailyLimitUsd, nil +} + +// AddDailyLimitUsd adds f to the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddDailyLimitUsd(f float64) { + if m.adddaily_limit_usd != nil { + *m.adddaily_limit_usd += f + } else { + m.adddaily_limit_usd = &f + } +} + +// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedDailyLimitUsd() (r float64, exists bool) { + v := m.adddaily_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + m.clearedFields[userplatformquota.FieldDailyLimitUsd] = struct{}{} +} + +// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DailyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDailyLimitUsd] + return ok +} + +// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldDailyLimitUsd) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetWeeklyLimitUsd(f float64) { + m.weekly_limit_usd = &f + m.addweekly_limit_usd = nil +} + +// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyLimitUsd() (r float64, exists bool) { + v := m.weekly_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) + } + return oldValue.WeeklyLimitUsd, nil +} + +// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddWeeklyLimitUsd(f float64) { + if m.addweekly_limit_usd != nil { + *m.addweekly_limit_usd += f + } else { + m.addweekly_limit_usd = &f + } +} + +// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { + v := m.addweekly_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + m.clearedFields[userplatformquota.FieldWeeklyLimitUsd] = struct{}{} +} + +// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) WeeklyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldWeeklyLimitUsd] + return ok +} + +// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldWeeklyLimitUsd) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetMonthlyLimitUsd(f float64) { + m.monthly_limit_usd = &f + m.addmonthly_limit_usd = nil +} + +// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyLimitUsd() (r float64, exists bool) { + v := m.monthly_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) + } + return oldValue.MonthlyLimitUsd, nil +} + +// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddMonthlyLimitUsd(f float64) { + if m.addmonthly_limit_usd != nil { + *m.addmonthly_limit_usd += f + } else { + m.addmonthly_limit_usd = &f + } +} + +// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { + v := m.addmonthly_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + m.clearedFields[userplatformquota.FieldMonthlyLimitUsd] = struct{}{} +} + +// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) MonthlyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldMonthlyLimitUsd] + return ok +} + +// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldMonthlyLimitUsd) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetDailyUsageUsd(f float64) { + m.daily_usage_usd = &f + m.adddaily_usage_usd = nil +} + +// DailyUsageUsd returns the value of the "daily_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyUsageUsd() (r float64, exists bool) { + v := m.daily_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldDailyUsageUsd returns the old "daily_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyUsageUsd: %w", err) + } + return oldValue.DailyUsageUsd, nil +} + +// AddDailyUsageUsd adds f to the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddDailyUsageUsd(f float64) { + if m.adddaily_usage_usd != nil { + *m.adddaily_usage_usd += f + } else { + m.adddaily_usage_usd = &f + } +} + +// AddedDailyUsageUsd returns the value that was added to the "daily_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedDailyUsageUsd() (r float64, exists bool) { + v := m.adddaily_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetDailyUsageUsd resets all changes to the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetDailyUsageUsd() { + m.daily_usage_usd = nil + m.adddaily_usage_usd = nil +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetWeeklyUsageUsd(f float64) { + m.weekly_usage_usd = &f + m.addweekly_usage_usd = nil +} + +// WeeklyUsageUsd returns the value of the "weekly_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyUsageUsd() (r float64, exists bool) { + v := m.weekly_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldWeeklyUsageUsd returns the old "weekly_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyUsageUsd: %w", err) + } + return oldValue.WeeklyUsageUsd, nil +} + +// AddWeeklyUsageUsd adds f to the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddWeeklyUsageUsd(f float64) { + if m.addweekly_usage_usd != nil { + *m.addweekly_usage_usd += f + } else { + m.addweekly_usage_usd = &f + } +} + +// AddedWeeklyUsageUsd returns the value that was added to the "weekly_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedWeeklyUsageUsd() (r float64, exists bool) { + v := m.addweekly_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetWeeklyUsageUsd resets all changes to the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyUsageUsd() { + m.weekly_usage_usd = nil + m.addweekly_usage_usd = nil +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetMonthlyUsageUsd(f float64) { + m.monthly_usage_usd = &f + m.addmonthly_usage_usd = nil +} + +// MonthlyUsageUsd returns the value of the "monthly_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyUsageUsd() (r float64, exists bool) { + v := m.monthly_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldMonthlyUsageUsd returns the old "monthly_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyUsageUsd: %w", err) + } + return oldValue.MonthlyUsageUsd, nil +} + +// AddMonthlyUsageUsd adds f to the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddMonthlyUsageUsd(f float64) { + if m.addmonthly_usage_usd != nil { + *m.addmonthly_usage_usd += f + } else { + m.addmonthly_usage_usd = &f + } +} + +// AddedMonthlyUsageUsd returns the value that was added to the "monthly_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedMonthlyUsageUsd() (r float64, exists bool) { + v := m.addmonthly_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetMonthlyUsageUsd resets all changes to the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyUsageUsd() { + m.monthly_usage_usd = nil + m.addmonthly_usage_usd = nil +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) SetDailyWindowStart(t time.Time) { + m.daily_window_start = &t +} + +// DailyWindowStart returns the value of the "daily_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyWindowStart() (r time.Time, exists bool) { + v := m.daily_window_start + if v == nil { + return + } + return *v, true +} + +// OldDailyWindowStart returns the old "daily_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyWindowStart: %w", err) + } + return oldValue.DailyWindowStart, nil +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) ClearDailyWindowStart() { + m.daily_window_start = nil + m.clearedFields[userplatformquota.FieldDailyWindowStart] = struct{}{} +} + +// DailyWindowStartCleared returns if the "daily_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DailyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDailyWindowStart] + return ok +} + +// ResetDailyWindowStart resets all changes to the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) ResetDailyWindowStart() { + m.daily_window_start = nil + delete(m.clearedFields, userplatformquota.FieldDailyWindowStart) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) SetWeeklyWindowStart(t time.Time) { + m.weekly_window_start = &t +} + +// WeeklyWindowStart returns the value of the "weekly_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyWindowStart() (r time.Time, exists bool) { + v := m.weekly_window_start + if v == nil { + return + } + return *v, true +} + +// OldWeeklyWindowStart returns the old "weekly_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyWindowStart: %w", err) + } + return oldValue.WeeklyWindowStart, nil +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) ClearWeeklyWindowStart() { + m.weekly_window_start = nil + m.clearedFields[userplatformquota.FieldWeeklyWindowStart] = struct{}{} +} + +// WeeklyWindowStartCleared returns if the "weekly_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) WeeklyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldWeeklyWindowStart] + return ok +} + +// ResetWeeklyWindowStart resets all changes to the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyWindowStart() { + m.weekly_window_start = nil + delete(m.clearedFields, userplatformquota.FieldWeeklyWindowStart) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) SetMonthlyWindowStart(t time.Time) { + m.monthly_window_start = &t +} + +// MonthlyWindowStart returns the value of the "monthly_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyWindowStart() (r time.Time, exists bool) { + v := m.monthly_window_start + if v == nil { + return + } + return *v, true +} + +// OldMonthlyWindowStart returns the old "monthly_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyWindowStart: %w", err) + } + return oldValue.MonthlyWindowStart, nil +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) ClearMonthlyWindowStart() { + m.monthly_window_start = nil + m.clearedFields[userplatformquota.FieldMonthlyWindowStart] = struct{}{} +} + +// MonthlyWindowStartCleared returns if the "monthly_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) MonthlyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldMonthlyWindowStart] + return ok +} + +// ResetMonthlyWindowStart resets all changes to the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyWindowStart() { + m.monthly_window_start = nil + delete(m.clearedFields, userplatformquota.FieldMonthlyWindowStart) +} + +// ClearUser clears the "user" edge to the User entity. +func (m *UserPlatformQuotaMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[userplatformquota.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *UserPlatformQuotaMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *UserPlatformQuotaMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *UserPlatformQuotaMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the UserPlatformQuotaMutation builder. +func (m *UserPlatformQuotaMutation) Where(ps ...predicate.UserPlatformQuota) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UserPlatformQuotaMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UserPlatformQuotaMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UserPlatformQuota, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UserPlatformQuotaMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UserPlatformQuotaMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (UserPlatformQuota). +func (m *UserPlatformQuotaMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *UserPlatformQuotaMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, userplatformquota.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, userplatformquota.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, userplatformquota.FieldDeletedAt) + } + if m.user != nil { + fields = append(fields, userplatformquota.FieldUserID) + } + if m.platform != nil { + fields = append(fields, userplatformquota.FieldPlatform) + } + if m.daily_limit_usd != nil { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.weekly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.monthly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.daily_usage_usd != nil { + fields = append(fields, userplatformquota.FieldDailyUsageUsd) + } + if m.weekly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyUsageUsd) + } + if m.monthly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyUsageUsd) + } + if m.daily_window_start != nil { + fields = append(fields, userplatformquota.FieldDailyWindowStart) + } + if m.weekly_window_start != nil { + fields = append(fields, userplatformquota.FieldWeeklyWindowStart) + } + if m.monthly_window_start != nil { + fields = append(fields, userplatformquota.FieldMonthlyWindowStart) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UserPlatformQuotaMutation) Field(name string) (ent.Value, bool) { + switch name { + case userplatformquota.FieldCreatedAt: + return m.CreatedAt() + case userplatformquota.FieldUpdatedAt: + return m.UpdatedAt() + case userplatformquota.FieldDeletedAt: + return m.DeletedAt() + case userplatformquota.FieldUserID: + return m.UserID() + case userplatformquota.FieldPlatform: + return m.Platform() + case userplatformquota.FieldDailyLimitUsd: + return m.DailyLimitUsd() + case userplatformquota.FieldWeeklyLimitUsd: + return m.WeeklyLimitUsd() + case userplatformquota.FieldMonthlyLimitUsd: + return m.MonthlyLimitUsd() + case userplatformquota.FieldDailyUsageUsd: + return m.DailyUsageUsd() + case userplatformquota.FieldWeeklyUsageUsd: + return m.WeeklyUsageUsd() + case userplatformquota.FieldMonthlyUsageUsd: + return m.MonthlyUsageUsd() + case userplatformquota.FieldDailyWindowStart: + return m.DailyWindowStart() + case userplatformquota.FieldWeeklyWindowStart: + return m.WeeklyWindowStart() + case userplatformquota.FieldMonthlyWindowStart: + return m.MonthlyWindowStart() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *UserPlatformQuotaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case userplatformquota.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case userplatformquota.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case userplatformquota.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case userplatformquota.FieldUserID: + return m.OldUserID(ctx) + case userplatformquota.FieldPlatform: + return m.OldPlatform(ctx) + case userplatformquota.FieldDailyLimitUsd: + return m.OldDailyLimitUsd(ctx) + case userplatformquota.FieldWeeklyLimitUsd: + return m.OldWeeklyLimitUsd(ctx) + case userplatformquota.FieldMonthlyLimitUsd: + return m.OldMonthlyLimitUsd(ctx) + case userplatformquota.FieldDailyUsageUsd: + return m.OldDailyUsageUsd(ctx) + case userplatformquota.FieldWeeklyUsageUsd: + return m.OldWeeklyUsageUsd(ctx) + case userplatformquota.FieldMonthlyUsageUsd: + return m.OldMonthlyUsageUsd(ctx) + case userplatformquota.FieldDailyWindowStart: + return m.OldDailyWindowStart(ctx) + case userplatformquota.FieldWeeklyWindowStart: + return m.OldWeeklyWindowStart(ctx) + case userplatformquota.FieldMonthlyWindowStart: + return m.OldMonthlyWindowStart(ctx) + } + return nil, fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserPlatformQuotaMutation) SetField(name string, value ent.Value) error { + switch name { + case userplatformquota.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case userplatformquota.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case userplatformquota.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case userplatformquota.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case userplatformquota.FieldPlatform: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatform(v) + return nil + case userplatformquota.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyLimitUsd(v) + return nil + case userplatformquota.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyLimitUsd(v) + return nil + case userplatformquota.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyLimitUsd(v) + return nil + case userplatformquota.FieldDailyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyUsageUsd(v) + return nil + case userplatformquota.FieldWeeklyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyUsageUsd(v) + return nil + case userplatformquota.FieldMonthlyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyUsageUsd(v) + return nil + case userplatformquota.FieldDailyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyWindowStart(v) + return nil + case userplatformquota.FieldWeeklyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyWindowStart(v) + return nil + case userplatformquota.FieldMonthlyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyWindowStart(v) + return nil + } + return fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *UserPlatformQuotaMutation) AddedFields() []string { + var fields []string + if m.adddaily_limit_usd != nil { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.addweekly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.addmonthly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.adddaily_usage_usd != nil { + fields = append(fields, userplatformquota.FieldDailyUsageUsd) + } + if m.addweekly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyUsageUsd) + } + if m.addmonthly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyUsageUsd) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UserPlatformQuotaMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case userplatformquota.FieldDailyLimitUsd: + return m.AddedDailyLimitUsd() + case userplatformquota.FieldWeeklyLimitUsd: + return m.AddedWeeklyLimitUsd() + case userplatformquota.FieldMonthlyLimitUsd: + return m.AddedMonthlyLimitUsd() + case userplatformquota.FieldDailyUsageUsd: + return m.AddedDailyUsageUsd() + case userplatformquota.FieldWeeklyUsageUsd: + return m.AddedWeeklyUsageUsd() + case userplatformquota.FieldMonthlyUsageUsd: + return m.AddedMonthlyUsageUsd() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserPlatformQuotaMutation) AddField(name string, value ent.Value) error { + switch name { + case userplatformquota.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyLimitUsd(v) + return nil + case userplatformquota.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyLimitUsd(v) + return nil + case userplatformquota.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyLimitUsd(v) + return nil + case userplatformquota.FieldDailyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyUsageUsd(v) + return nil + case userplatformquota.FieldWeeklyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyUsageUsd(v) + return nil + case userplatformquota.FieldMonthlyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyUsageUsd(v) + return nil + } + return fmt.Errorf("unknown UserPlatformQuota numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UserPlatformQuotaMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(userplatformquota.FieldDeletedAt) { + fields = append(fields, userplatformquota.FieldDeletedAt) + } + if m.FieldCleared(userplatformquota.FieldDailyLimitUsd) { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldWeeklyLimitUsd) { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldMonthlyLimitUsd) { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldDailyWindowStart) { + fields = append(fields, userplatformquota.FieldDailyWindowStart) + } + if m.FieldCleared(userplatformquota.FieldWeeklyWindowStart) { + fields = append(fields, userplatformquota.FieldWeeklyWindowStart) + } + if m.FieldCleared(userplatformquota.FieldMonthlyWindowStart) { + fields = append(fields, userplatformquota.FieldMonthlyWindowStart) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UserPlatformQuotaMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UserPlatformQuotaMutation) ClearField(name string) error { + switch name { + case userplatformquota.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case userplatformquota.FieldDailyLimitUsd: + m.ClearDailyLimitUsd() + return nil + case userplatformquota.FieldWeeklyLimitUsd: + m.ClearWeeklyLimitUsd() + return nil + case userplatformquota.FieldMonthlyLimitUsd: + m.ClearMonthlyLimitUsd() + return nil + case userplatformquota.FieldDailyWindowStart: + m.ClearDailyWindowStart() + return nil + case userplatformquota.FieldWeeklyWindowStart: + m.ClearWeeklyWindowStart() + return nil + case userplatformquota.FieldMonthlyWindowStart: + m.ClearMonthlyWindowStart() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *UserPlatformQuotaMutation) ResetField(name string) error { + switch name { + case userplatformquota.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case userplatformquota.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case userplatformquota.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case userplatformquota.FieldUserID: + m.ResetUserID() + return nil + case userplatformquota.FieldPlatform: + m.ResetPlatform() + return nil + case userplatformquota.FieldDailyLimitUsd: + m.ResetDailyLimitUsd() + return nil + case userplatformquota.FieldWeeklyLimitUsd: + m.ResetWeeklyLimitUsd() + return nil + case userplatformquota.FieldMonthlyLimitUsd: + m.ResetMonthlyLimitUsd() + return nil + case userplatformquota.FieldDailyUsageUsd: + m.ResetDailyUsageUsd() + return nil + case userplatformquota.FieldWeeklyUsageUsd: + m.ResetWeeklyUsageUsd() + return nil + case userplatformquota.FieldMonthlyUsageUsd: + m.ResetMonthlyUsageUsd() + return nil + case userplatformquota.FieldDailyWindowStart: + m.ResetDailyWindowStart() + return nil + case userplatformquota.FieldWeeklyWindowStart: + m.ResetWeeklyWindowStart() + return nil + case userplatformquota.FieldMonthlyWindowStart: + m.ResetMonthlyWindowStart() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UserPlatformQuotaMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, userplatformquota.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UserPlatformQuotaMutation) AddedIDs(name string) []ent.Value { + switch name { + case userplatformquota.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UserPlatformQuotaMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UserPlatformQuotaMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UserPlatformQuotaMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, userplatformquota.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *UserPlatformQuotaMutation) EdgeCleared(name string) bool { + switch name { + case userplatformquota.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UserPlatformQuotaMutation) ClearEdge(name string) error { + switch name { + case userplatformquota.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UserPlatformQuotaMutation) ResetEdge(name string) error { + switch name { + case userplatformquota.EdgeUser: + m.ResetUser() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota edge %s", name) +} + // UserSubscriptionMutation represents an operation that mutates the UserSubscription nodes in the graph. type UserSubscriptionMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index dc86471e..ab4d7d18 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector) // UserAttributeValue is the predicate function for userattributevalue builders. type UserAttributeValue func(*sql.Selector) +// UserPlatformQuota is the predicate function for userplatformquota builders. +type UserPlatformQuota func(*sql.Selector) + // UserSubscription is the predicate function for usersubscription builders. type UserSubscription func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6d541e2f..fdb837e8 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -39,6 +39,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -869,8 +870,12 @@ func init() { groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) + // groupDescModelsListConfig is the schema descriptor for models_list_config field. + groupDescModelsListConfig := groupFields[30].Descriptor() + // group.DefaultModelsListConfig holds the default value on creation for the models_list_config field. + group.DefaultModelsListConfig = groupDescModelsListConfig.Default.(domain.GroupModelsListConfig) // groupDescRpmLimit is the schema descriptor for rpm_limit field. - groupDescRpmLimit := groupFields[30].Descriptor() + groupDescRpmLimit := groupFields[31].Descriptor() // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() @@ -1997,6 +2002,56 @@ func init() { userattributevalueDescValue := userattributevalueFields[2].Descriptor() // userattributevalue.DefaultValue holds the default value on creation for the value field. userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string) + userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin() + userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks() + userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0] + userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors() + userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0] + userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields() + _ = userplatformquotaMixinFields0 + userplatformquotaFields := schema.UserPlatformQuota{}.Fields() + _ = userplatformquotaFields + // userplatformquotaDescCreatedAt is the schema descriptor for created_at field. + userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor() + // userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field. + userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time) + // userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field. + userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor() + // userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field. + userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time) + // userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time) + // userplatformquotaDescPlatform is the schema descriptor for platform field. + userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor() + // userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save. + userplatformquota.PlatformValidator = func() func(string) error { + validators := userplatformquotaDescPlatform.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(platform string) error { + for _, fn := range fns { + if err := fn(platform); err != nil { + return err + } + } + return nil + } + }() + // userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field. + userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor() + // userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field. + userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64) + // userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field. + userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor() + // userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field. + userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64) + // userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field. + userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor() + // userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field. + userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64) usersubscriptionMixin := schema.UserSubscription{}.Mixin() usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks() usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0] diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d47e8710..2a1715f8 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -155,6 +155,10 @@ func (Group) Fields() []ent.Field { Default(domain.OpenAIMessagesDispatchModelConfig{}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), + field.JSON("models_list_config", domain.GroupModelsListConfig{}). + Default(domain.GroupModelsListConfig{}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度"), // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。 field.Int("rpm_limit"). diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index c6e04273..127b5af9 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge { edge.To("auth_identities", AuthIdentity.Type). Annotations(entsql.OnDelete(entsql.Cascade)), edge.To("pending_auth_sessions", PendingAuthSession.Type), + edge.To("platform_quotas", UserPlatformQuota.Type), } } diff --git a/backend/ent/schema/user_platform_quota.go b/backend/ent/schema/user_platform_quota.go new file mode 100644 index 00000000..8fd8acc0 --- /dev/null +++ b/backend/ent/schema/user_platform_quota.go @@ -0,0 +1,113 @@ +package schema + +import ( + "fmt" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" +) + +// UserPlatformQuota holds the schema definition for per-user per-platform quota. +type UserPlatformQuota struct { + ent.Schema +} + +func (UserPlatformQuota) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "user_platform_quotas"}, + } +} + +func (UserPlatformQuota) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + mixins.SoftDeleteMixin{}, + } +} + +func (UserPlatformQuota) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("platform"). + MaxLen(32). + NotEmpty(). + Validate(func(s string) error { + // 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms; + // 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。 + switch s { + case "anthropic", "openai", "gemini", "antigravity": + return nil + default: + return fmt.Errorf("platform %q is not allowed", s) + } + }), + + // 日 / 周 / 月 USD 上限: + // nil / not set → 无限额(完全放行) + // 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立) + // > 0 → USD 限额上限 + field.Float("daily_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("weekly_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("monthly_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + + // 当前窗口已用量(USD,preflight 时与 limit 比较) + field.Float("daily_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("weekly_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("monthly_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + + // 窗口起点(NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底) + field.Time("daily_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("weekly_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("monthly_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (UserPlatformQuota) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("platform_quotas"). + Field("user_id"). + Unique(). + Required(), + } +} + +func (UserPlatformQuota) Indexes() []ent.Index { + return []ent.Index{ + // 软删除友好:只对未删记录唯一 + index.Fields("user_id", "platform"). + Unique(). + Annotations(entsql.IndexWhere("deleted_at IS NULL")), + index.Fields("user_id"), + } +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 611028e9..846cfcd4 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -80,6 +80,8 @@ type Tx struct { UserAttributeDefinition *UserAttributeDefinitionClient // UserAttributeValue is the client for interacting with the UserAttributeValue builders. UserAttributeValue *UserAttributeValueClient + // UserPlatformQuota is the client for interacting with the UserPlatformQuota builders. + UserPlatformQuota *UserPlatformQuotaClient // UserSubscription is the client for interacting with the UserSubscription builders. UserSubscription *UserSubscriptionClient @@ -246,6 +248,7 @@ func (tx *Tx) init() { tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config) tx.UserAttributeValue = NewUserAttributeValueClient(tx.config) + tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config) tx.UserSubscription = NewUserSubscriptionClient(tx.config) } diff --git a/backend/ent/user.go b/backend/ent/user.go index 06670444..486f2f64 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -95,11 +95,13 @@ type UserEdges struct { AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` // PendingAuthSessions holds the value of the pending_auth_sessions edge. PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` + // PlatformQuotas holds the value of the platform_quotas edge. + PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [13]bool + loadedTypes [14]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { return nil, &NotLoadedError{edge: "pending_auth_sessions"} } +// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) { + if e.loadedTypes[12] { + return e.PlatformQuotas, nil + } + return nil, &NotLoadedError{edge: "platform_quotas"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[12] { + if e.loadedTypes[13] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { return NewUserClient(_m.config).QueryPendingAuthSessions(_m) } +// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity. +func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery { + return NewUserClient(_m.config).QueryPlatformQuotas(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index e11a8a32..ff40445b 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -85,6 +85,8 @@ const ( EdgeAuthIdentities = "auth_identities" // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. EdgePendingAuthSessions = "pending_auth_sessions" + // EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations. + EdgePlatformQuotas = "platform_quotas" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -171,6 +173,13 @@ const ( PendingAuthSessionsInverseTable = "pending_auth_sessions" // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. PendingAuthSessionsColumn = "target_user_id" + // PlatformQuotasTable is the table that holds the platform_quotas relation/edge. + PlatformQuotasTable = "user_platform_quotas" + // PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity. + // It exists in this package in order to avoid circular dependency with the "userplatformquota" package. + PlatformQuotasInverseTable = "user_platform_quotas" + // PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge. + PlatformQuotasColumn = "user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti } } +// ByPlatformQuotasCount orders the results by platform_quotas count. +func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...) + } +} + +// ByPlatformQuotas orders the results by platform_quotas terms. +func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), ) } +func newPlatformQuotasStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PlatformQuotasInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 05d3b35b..a18cf497 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate }) } +// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge. +func HasPlatformQuotas() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates). +func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPlatformQuotasStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index b4161128..92f1bd5e 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea return _c.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate { + _c.mutation.AddPlatformQuotaIDs(ids...) + return _c +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index f1ee5cfe..86f4eccd 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -48,6 +49,7 @@ type UserQuery struct { withPaymentOrders *PaymentOrderQuery withAuthIdentities *AuthIdentityQuery withPendingAuthSessions *PendingAuthSessionQuery + withPlatformQuotas *UserPlatformQuotaQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { return query } +// QueryPlatformQuotas chains the current query on the "platform_quotas" edge. +func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery { + query := (&UserPlatformQuotaClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery { withPaymentOrders: _q.withPaymentOrders.Clone(), withAuthIdentities: _q.withAuthIdentities.Clone(), withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), + withPlatformQuotas: _q.withPlatformQuotas.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue return _q } +// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to +// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery { + query := (&UserPlatformQuotaClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPlatformQuotas = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [13]bool{ + loadedTypes = [14]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withPaymentOrders != nil, _q.withAuthIdentities != nil, _q.withPendingAuthSessions != nil, + _q.withPlatformQuotas != nil, _q.withUserAllowedGroups != nil, } ) @@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withPlatformQuotas; query != nil { + if err := _q.loadPlatformQuotas(ctx, query, nodes, + func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} }, + func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending } return nil } +func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(userplatformquota.FieldUserID) + } + query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index f1d759ce..67d3f8e6 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -23,6 +23,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda return _u.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPlatformQuotaIDs(ids...) + return _u +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU return _u.RemovePendingAuthSessionIDs(ids...) } +// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate { + _u.mutation.ClearPlatformQuotas() + return _u +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs. +func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePlatformQuotaIDs(ids...) + return _u +} + +// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities. +func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePlatformQuotaIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU return _u.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPlatformQuotaIDs(ids...) + return _u +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us return _u.RemovePendingAuthSessionIDs(ids...) } +// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne { + _u.mutation.ClearPlatformQuotas() + return _u +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs. +func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePlatformQuotaIDs(ids...) + return _u +} + +// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities. +func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePlatformQuotaIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/userplatformquota.go b/backend/ent/userplatformquota.go new file mode 100644 index 00000000..d00a1f9a --- /dev/null +++ b/backend/ent/userplatformquota.go @@ -0,0 +1,301 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuota is the model entity for the UserPlatformQuota schema. +type UserPlatformQuota struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // Platform holds the value of the "platform" field. + Platform string `json:"platform,omitempty"` + // DailyLimitUsd holds the value of the "daily_limit_usd" field. + DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"` + // WeeklyLimitUsd holds the value of the "weekly_limit_usd" field. + WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"` + // MonthlyLimitUsd holds the value of the "monthly_limit_usd" field. + MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` + // DailyUsageUsd holds the value of the "daily_usage_usd" field. + DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"` + // WeeklyUsageUsd holds the value of the "weekly_usage_usd" field. + WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"` + // MonthlyUsageUsd holds the value of the "monthly_usage_usd" field. + MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"` + // DailyWindowStart holds the value of the "daily_window_start" field. + DailyWindowStart *time.Time `json:"daily_window_start,omitempty"` + // WeeklyWindowStart holds the value of the "weekly_window_start" field. + WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"` + // MonthlyWindowStart holds the value of the "monthly_window_start" field. + MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the UserPlatformQuotaQuery when eager-loading is set. + Edges UserPlatformQuotaEdges `json:"edges"` + selectValues sql.SelectValues +} + +// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph. +type UserPlatformQuotaEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd: + values[i] = new(sql.NullFloat64) + case userplatformquota.FieldID, userplatformquota.FieldUserID: + values[i] = new(sql.NullInt64) + case userplatformquota.FieldPlatform: + values[i] = new(sql.NullString) + case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UserPlatformQuota fields. +func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case userplatformquota.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case userplatformquota.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case userplatformquota.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case userplatformquota.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + _m.DeletedAt = new(time.Time) + *_m.DeletedAt = value.Time + } + case userplatformquota.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case userplatformquota.FieldPlatform: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field platform", values[i]) + } else if value.Valid { + _m.Platform = value.String + } + case userplatformquota.FieldDailyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i]) + } else if value.Valid { + _m.DailyLimitUsd = new(float64) + *_m.DailyLimitUsd = value.Float64 + } + case userplatformquota.FieldWeeklyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i]) + } else if value.Valid { + _m.WeeklyLimitUsd = new(float64) + *_m.WeeklyLimitUsd = value.Float64 + } + case userplatformquota.FieldMonthlyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i]) + } else if value.Valid { + _m.MonthlyLimitUsd = new(float64) + *_m.MonthlyLimitUsd = value.Float64 + } + case userplatformquota.FieldDailyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i]) + } else if value.Valid { + _m.DailyUsageUsd = value.Float64 + } + case userplatformquota.FieldWeeklyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i]) + } else if value.Valid { + _m.WeeklyUsageUsd = value.Float64 + } + case userplatformquota.FieldMonthlyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i]) + } else if value.Valid { + _m.MonthlyUsageUsd = value.Float64 + } + case userplatformquota.FieldDailyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field daily_window_start", values[i]) + } else if value.Valid { + _m.DailyWindowStart = new(time.Time) + *_m.DailyWindowStart = value.Time + } + case userplatformquota.FieldWeeklyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i]) + } else if value.Valid { + _m.WeeklyWindowStart = new(time.Time) + *_m.WeeklyWindowStart = value.Time + } + case userplatformquota.FieldMonthlyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i]) + } else if value.Valid { + _m.MonthlyWindowStart = new(time.Time) + *_m.MonthlyWindowStart = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota. +// This includes values selected through modifiers, order, etc. +func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the UserPlatformQuota entity. +func (_m *UserPlatformQuota) QueryUser() *UserQuery { + return NewUserPlatformQuotaClient(_m.config).QueryUser(_m) +} + +// Update returns a builder for updating this UserPlatformQuota. +// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne { + return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: UserPlatformQuota is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *UserPlatformQuota) String() string { + var builder strings.Builder + builder.WriteString("UserPlatformQuota(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("platform=") + builder.WriteString(_m.Platform) + builder.WriteString(", ") + if v := _m.DailyLimitUsd; v != nil { + builder.WriteString("daily_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.WeeklyLimitUsd; v != nil { + builder.WriteString("weekly_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.MonthlyLimitUsd; v != nil { + builder.WriteString("monthly_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("daily_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd)) + builder.WriteString(", ") + builder.WriteString("weekly_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd)) + builder.WriteString(", ") + builder.WriteString("monthly_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd)) + builder.WriteString(", ") + if v := _m.DailyWindowStart; v != nil { + builder.WriteString("daily_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.WeeklyWindowStart; v != nil { + builder.WriteString("weekly_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.MonthlyWindowStart; v != nil { + builder.WriteString("monthly_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota. +type UserPlatformQuotaSlice []*UserPlatformQuota diff --git a/backend/ent/userplatformquota/userplatformquota.go b/backend/ent/userplatformquota/userplatformquota.go new file mode 100644 index 00000000..01903853 --- /dev/null +++ b/backend/ent/userplatformquota/userplatformquota.go @@ -0,0 +1,202 @@ +// Code generated by ent, DO NOT EDIT. + +package userplatformquota + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the userplatformquota type in the database. + Label = "user_platform_quota" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldPlatform holds the string denoting the platform field in the database. + FieldPlatform = "platform" + // FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database. + FieldDailyLimitUsd = "daily_limit_usd" + // FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database. + FieldWeeklyLimitUsd = "weekly_limit_usd" + // FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database. + FieldMonthlyLimitUsd = "monthly_limit_usd" + // FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database. + FieldDailyUsageUsd = "daily_usage_usd" + // FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database. + FieldWeeklyUsageUsd = "weekly_usage_usd" + // FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database. + FieldMonthlyUsageUsd = "monthly_usage_usd" + // FieldDailyWindowStart holds the string denoting the daily_window_start field in the database. + FieldDailyWindowStart = "daily_window_start" + // FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database. + FieldWeeklyWindowStart = "weekly_window_start" + // FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database. + FieldMonthlyWindowStart = "monthly_window_start" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // Table holds the table name of the userplatformquota in the database. + Table = "user_platform_quotas" + // UserTable is the table that holds the user relation/edge. + UserTable = "user_platform_quotas" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" +) + +// Columns holds all SQL columns for userplatformquota fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldUserID, + FieldPlatform, + FieldDailyLimitUsd, + FieldWeeklyLimitUsd, + FieldMonthlyLimitUsd, + FieldDailyUsageUsd, + FieldWeeklyUsageUsd, + FieldMonthlyUsageUsd, + FieldDailyWindowStart, + FieldWeeklyWindowStart, + FieldMonthlyWindowStart, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/Wei-Shaw/sub2api/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // PlatformValidator is a validator for the "platform" field. It is called by the builders before save. + PlatformValidator func(string) error + // DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field. + DefaultDailyUsageUsd float64 + // DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field. + DefaultWeeklyUsageUsd float64 + // DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field. + DefaultMonthlyUsageUsd float64 +) + +// OrderOption defines the ordering options for the UserPlatformQuota queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByPlatform orders the results by the platform field. +func ByPlatform(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPlatform, opts...).ToFunc() +} + +// ByDailyLimitUsd orders the results by the daily_limit_usd field. +func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc() +} + +// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field. +func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc() +} + +// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field. +func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc() +} + +// ByDailyUsageUsd orders the results by the daily_usage_usd field. +func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc() +} + +// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field. +func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc() +} + +// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field. +func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc() +} + +// ByDailyWindowStart orders the results by the daily_window_start field. +func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc() +} + +// ByWeeklyWindowStart orders the results by the weekly_window_start field. +func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc() +} + +// ByMonthlyWindowStart orders the results by the monthly_window_start field. +func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} diff --git a/backend/ent/userplatformquota/where.go b/backend/ent/userplatformquota/where.go new file mode 100644 index 00000000..37d371c6 --- /dev/null +++ b/backend/ent/userplatformquota/where.go @@ -0,0 +1,799 @@ +// Code generated by ent, DO NOT EDIT. + +package userplatformquota + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v)) +} + +// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ. +func Platform(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v)) +} + +// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ. +func DailyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v)) +} + +// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ. +func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v)) +} + +// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ. +func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v)) +} + +// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ. +func DailyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v)) +} + +// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ. +func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v)) +} + +// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ. +func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v)) +} + +// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ. +func DailyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v)) +} + +// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ. +func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v)) +} + +// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ. +func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...)) +} + +// PlatformEQ applies the EQ predicate on the "platform" field. +func PlatformEQ(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v)) +} + +// PlatformNEQ applies the NEQ predicate on the "platform" field. +func PlatformNEQ(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v)) +} + +// PlatformIn applies the In predicate on the "platform" field. +func PlatformIn(vs ...string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...)) +} + +// PlatformNotIn applies the NotIn predicate on the "platform" field. +func PlatformNotIn(vs ...string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...)) +} + +// PlatformGT applies the GT predicate on the "platform" field. +func PlatformGT(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v)) +} + +// PlatformGTE applies the GTE predicate on the "platform" field. +func PlatformGTE(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v)) +} + +// PlatformLT applies the LT predicate on the "platform" field. +func PlatformLT(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v)) +} + +// PlatformLTE applies the LTE predicate on the "platform" field. +func PlatformLTE(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v)) +} + +// PlatformContains applies the Contains predicate on the "platform" field. +func PlatformContains(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v)) +} + +// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field. +func PlatformHasPrefix(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v)) +} + +// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field. +func PlatformHasSuffix(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v)) +} + +// PlatformEqualFold applies the EqualFold predicate on the "platform" field. +func PlatformEqualFold(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v)) +} + +// PlatformContainsFold applies the ContainsFold predicate on the "platform" field. +func PlatformContainsFold(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v)) +} + +// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field. +func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field. +func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field. +func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...)) +} + +// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field. +func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...)) +} + +// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field. +func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field. +func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field. +func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field. +func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field. +func DailyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd)) +} + +// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field. +func DailyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd)) +} + +// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...)) +} + +// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...)) +} + +// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd)) +} + +// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd)) +} + +// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...)) +} + +// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...)) +} + +// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd)) +} + +// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd)) +} + +// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field. +func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field. +func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field. +func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...)) +} + +// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field. +func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...)) +} + +// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field. +func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field. +func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field. +func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field. +func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v)) +} + +// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...)) +} + +// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...)) +} + +// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v)) +} + +// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...)) +} + +// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...)) +} + +// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v)) +} + +// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field. +func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v)) +} + +// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field. +func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v)) +} + +// DailyWindowStartIn applies the In predicate on the "daily_window_start" field. +func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...)) +} + +// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field. +func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...)) +} + +// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field. +func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v)) +} + +// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field. +func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v)) +} + +// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field. +func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v)) +} + +// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field. +func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v)) +} + +// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field. +func DailyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart)) +} + +// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field. +func DailyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart)) +} + +// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field. +func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field. +func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field. +func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...)) +} + +// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field. +func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...)) +} + +// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field. +func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field. +func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field. +func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field. +func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field. +func WeeklyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart)) +} + +// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field. +func WeeklyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart)) +} + +// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field. +func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field. +func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field. +func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...)) +} + +// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field. +func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...)) +} + +// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field. +func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field. +func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field. +func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field. +func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field. +func MonthlyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart)) +} + +// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field. +func MonthlyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.NotPredicates(p)) +} diff --git a/backend/ent/userplatformquota_create.go b/backend/ent/userplatformquota_create.go new file mode 100644 index 00000000..da6c3ce6 --- /dev/null +++ b/backend/ent/userplatformquota_create.go @@ -0,0 +1,1513 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaCreate is the builder for creating a UserPlatformQuota entity. +type UserPlatformQuotaCreate struct { + config + mutation *UserPlatformQuotaMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *UserPlatformQuotaCreate) SetCreatedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableCreatedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *UserPlatformQuotaCreate) SetUpdatedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableUpdatedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetDeletedAt sets the "deleted_at" field. +func (_c *UserPlatformQuotaCreate) SetDeletedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetDeletedAt(v) + return _c +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDeletedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *UserPlatformQuotaCreate) SetUserID(v int64) *UserPlatformQuotaCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetPlatform sets the "platform" field. +func (_c *UserPlatformQuotaCreate) SetPlatform(v string) *UserPlatformQuotaCreate { + _c.mutation.SetPlatform(v) + return _c +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetDailyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetDailyLimitUsd(v) + return _c +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyLimitUsd(*v) + } + return _c +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyLimitUsd(v) + return _c +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyLimitUsd(*v) + } + return _c +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyLimitUsd(v) + return _c +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyLimitUsd(*v) + } + return _c +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetDailyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetDailyUsageUsd(v) + return _c +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyUsageUsd(*v) + } + return _c +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyUsageUsd(v) + return _c +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyUsageUsd(*v) + } + return _c +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyUsageUsd(v) + return _c +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyUsageUsd(*v) + } + return _c +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_c *UserPlatformQuotaCreate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetDailyWindowStart(v) + return _c +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyWindowStart(*v) + } + return _c +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyWindowStart(v) + return _c +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyWindowStart(*v) + } + return _c +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyWindowStart(v) + return _c +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyWindowStart(*v) + } + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *UserPlatformQuotaCreate) SetUser(v *User) *UserPlatformQuotaCreate { + return _c.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_c *UserPlatformQuotaCreate) Mutation() *UserPlatformQuotaMutation { + return _c.mutation +} + +// Save creates the UserPlatformQuota in the database. +func (_c *UserPlatformQuotaCreate) Save(ctx context.Context) (*UserPlatformQuota, error) { + if err := _c.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UserPlatformQuotaCreate) SaveX(ctx context.Context) *UserPlatformQuota { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserPlatformQuotaCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserPlatformQuotaCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *UserPlatformQuotaCreate) defaults() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + if userplatformquota.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + if userplatformquota.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.DailyUsageUsd(); !ok { + v := userplatformquota.DefaultDailyUsageUsd + _c.mutation.SetDailyUsageUsd(v) + } + if _, ok := _c.mutation.WeeklyUsageUsd(); !ok { + v := userplatformquota.DefaultWeeklyUsageUsd + _c.mutation.SetWeeklyUsageUsd(v) + } + if _, ok := _c.mutation.MonthlyUsageUsd(); !ok { + v := userplatformquota.DefaultMonthlyUsageUsd + _c.mutation.SetMonthlyUsageUsd(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UserPlatformQuotaCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserPlatformQuota.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserPlatformQuota.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserPlatformQuota.user_id"`)} + } + if _, ok := _c.mutation.Platform(); !ok { + return &ValidationError{Name: "platform", err: errors.New(`ent: missing required field "UserPlatformQuota.platform"`)} + } + if v, ok := _c.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _, ok := _c.mutation.DailyUsageUsd(); !ok { + return &ValidationError{Name: "daily_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.daily_usage_usd"`)} + } + if _, ok := _c.mutation.WeeklyUsageUsd(); !ok { + return &ValidationError{Name: "weekly_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.weekly_usage_usd"`)} + } + if _, ok := _c.mutation.MonthlyUsageUsd(); !ok { + return &ValidationError{Name: "monthly_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.monthly_usage_usd"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserPlatformQuota.user"`)} + } + return nil +} + +func (_c *UserPlatformQuotaCreate) sqlSave(ctx context.Context) (*UserPlatformQuota, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UserPlatformQuotaCreate) createSpec() (*UserPlatformQuota, *sqlgraph.CreateSpec) { + var ( + _node = &UserPlatformQuota{config: _c.config} + _spec = sqlgraph.NewCreateSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(userplatformquota.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := _c.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + _node.Platform = value + } + if value, ok := _c.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + _node.DailyLimitUsd = &value + } + if value, ok := _c.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + _node.WeeklyLimitUsd = &value + } + if value, ok := _c.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + _node.MonthlyLimitUsd = &value + } + if value, ok := _c.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + _node.DailyUsageUsd = value + } + if value, ok := _c.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + _node.WeeklyUsageUsd = value + } + if value, ok := _c.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + _node.MonthlyUsageUsd = value + } + if value, ok := _c.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + _node.DailyWindowStart = &value + } + if value, ok := _c.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + _node.WeeklyWindowStart = &value + } + if value, ok := _c.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + _node.MonthlyWindowStart = &value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UserPlatformQuota.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserPlatformQuotaUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *UserPlatformQuotaCreate) OnConflict(opts ...sql.ConflictOption) *UserPlatformQuotaUpsertOne { + _c.conflict = opts + return &UserPlatformQuotaUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UserPlatformQuotaCreate) OnConflictColumns(columns ...string) *UserPlatformQuotaUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UserPlatformQuotaUpsertOne{ + create: _c, + } +} + +type ( + // UserPlatformQuotaUpsertOne is the builder for "upsert"-ing + // one UserPlatformQuota node. + UserPlatformQuotaUpsertOne struct { + create *UserPlatformQuotaCreate + } + + // UserPlatformQuotaUpsert is the "OnConflict" setter. + UserPlatformQuotaUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsert) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateUpdatedAt() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsert) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDeletedAt() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsert) ClearDeletedAt() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDeletedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsert) SetUserID(v int64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateUserID() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldUserID) + return u +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsert) SetPlatform(v string) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldPlatform, v) + return u +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdatePlatform() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldPlatform) + return u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyLimitUsd, v) + return u +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyLimitUsd) + return u +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldDailyLimitUsd, v) + return u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearDailyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDailyLimitUsd) + return u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyLimitUsd, v) + return u +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyLimitUsd) + return u +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldWeeklyLimitUsd, v) + return u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldWeeklyLimitUsd) + return u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyLimitUsd, v) + return u +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyLimitUsd) + return u +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldMonthlyLimitUsd, v) + return u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldMonthlyLimitUsd) + return u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyUsageUsd, v) + return u +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyUsageUsd) + return u +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldDailyUsageUsd, v) + return u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyUsageUsd, v) + return u +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyUsageUsd) + return u +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldWeeklyUsageUsd, v) + return u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyUsageUsd, v) + return u +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyUsageUsd) + return u +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldMonthlyUsageUsd, v) + return u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsert) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyWindowStart, v) + return u +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyWindowStart) + return u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearDailyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDailyWindowStart) + return u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyWindowStart, v) + return u +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyWindowStart) + return u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearWeeklyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldWeeklyWindowStart) + return u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyWindowStart, v) + return u +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyWindowStart) + return u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearMonthlyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldMonthlyWindowStart) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertOne) UpdateNewValues() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(userplatformquota.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertOne) Ignore() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserPlatformQuotaUpsertOne) DoNothing() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserPlatformQuotaCreate.OnConflict +// documentation for more info. +func (u *UserPlatformQuotaUpsertOne) Update(set func(*UserPlatformQuotaUpsert)) *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserPlatformQuotaUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsertOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateUpdatedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsertOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDeletedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsertOne) ClearDeletedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsertOne) SetUserID(v int64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateUserID() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUserID() + }) +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsertOne) SetPlatform(v string) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetPlatform(v) + }) +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdatePlatform() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdatePlatform() + }) +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyLimitUsd(v) + }) +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyLimitUsd(v) + }) +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyLimitUsd() + }) +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearDailyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyLimitUsd() + }) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyLimitUsd(v) + }) +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyLimitUsd(v) + }) +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyLimitUsd() + }) +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyLimitUsd() + }) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyLimitUsd(v) + }) +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyLimitUsd(v) + }) +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyLimitUsd() + }) +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyLimitUsd() + }) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyUsageUsd(v) + }) +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyUsageUsd(v) + }) +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyUsageUsd() + }) +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyUsageUsd(v) + }) +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyUsageUsd(v) + }) +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyUsageUsd() + }) +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyUsageUsd(v) + }) +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyUsageUsd(v) + }) +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyUsageUsd() + }) +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyWindowStart(v) + }) +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyWindowStart() + }) +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearDailyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyWindowStart() + }) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyWindowStart(v) + }) +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyWindowStart() + }) +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyWindowStart() + }) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyWindowStart(v) + }) +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyWindowStart() + }) +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyWindowStart() + }) +} + +// Exec executes the query. +func (u *UserPlatformQuotaUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserPlatformQuotaCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *UserPlatformQuotaUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// UserPlatformQuotaCreateBulk is the builder for creating many UserPlatformQuota entities in bulk. +type UserPlatformQuotaCreateBulk struct { + config + err error + builders []*UserPlatformQuotaCreate + conflict []sql.ConflictOption +} + +// Save creates the UserPlatformQuota entities in the database. +func (_c *UserPlatformQuotaCreateBulk) Save(ctx context.Context) ([]*UserPlatformQuota, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*UserPlatformQuota, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UserPlatformQuotaMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *UserPlatformQuotaCreateBulk) SaveX(ctx context.Context) []*UserPlatformQuota { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserPlatformQuotaCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserPlatformQuotaCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UserPlatformQuota.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserPlatformQuotaUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *UserPlatformQuotaCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserPlatformQuotaUpsertBulk { + _c.conflict = opts + return &UserPlatformQuotaUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UserPlatformQuotaCreateBulk) OnConflictColumns(columns ...string) *UserPlatformQuotaUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UserPlatformQuotaUpsertBulk{ + create: _c, + } +} + +// UserPlatformQuotaUpsertBulk is the builder for "upsert"-ing +// a bulk of UserPlatformQuota nodes. +type UserPlatformQuotaUpsertBulk struct { + create *UserPlatformQuotaCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertBulk) UpdateNewValues() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(userplatformquota.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertBulk) Ignore() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserPlatformQuotaUpsertBulk) DoNothing() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserPlatformQuotaCreateBulk.OnConflict +// documentation for more info. +func (u *UserPlatformQuotaUpsertBulk) Update(set func(*UserPlatformQuotaUpsert)) *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserPlatformQuotaUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsertBulk) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateUpdatedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsertBulk) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDeletedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDeletedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsertBulk) SetUserID(v int64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateUserID() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUserID() + }) +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsertBulk) SetPlatform(v string) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetPlatform(v) + }) +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdatePlatform() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdatePlatform() + }) +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyLimitUsd(v) + }) +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyLimitUsd(v) + }) +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyLimitUsd() + }) +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDailyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyLimitUsd() + }) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyLimitUsd(v) + }) +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyLimitUsd(v) + }) +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyLimitUsd() + }) +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyLimitUsd() + }) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyLimitUsd(v) + }) +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyLimitUsd(v) + }) +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyLimitUsd() + }) +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyLimitUsd() + }) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyUsageUsd(v) + }) +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyUsageUsd(v) + }) +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyUsageUsd() + }) +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyUsageUsd(v) + }) +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyUsageUsd(v) + }) +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyUsageUsd() + }) +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyUsageUsd(v) + }) +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyUsageUsd(v) + }) +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyUsageUsd() + }) +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyWindowStart(v) + }) +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyWindowStart() + }) +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDailyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyWindowStart() + }) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyWindowStart(v) + }) +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyWindowStart() + }) +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearWeeklyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyWindowStart() + }) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyWindowStart(v) + }) +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyWindowStart() + }) +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearMonthlyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyWindowStart() + }) +} + +// Exec executes the query. +func (u *UserPlatformQuotaUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserPlatformQuotaCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserPlatformQuotaCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/userplatformquota_delete.go b/backend/ent/userplatformquota_delete.go new file mode 100644 index 00000000..1d80e31b --- /dev/null +++ b/backend/ent/userplatformquota_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity. +type UserPlatformQuotaDelete struct { + config + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// Where appends a list predicates to the UserPlatformQuotaDelete builder. +func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity. +type UserPlatformQuotaDeleteOne struct { + _d *UserPlatformQuotaDelete +} + +// Where appends a list predicates to the UserPlatformQuotaDelete builder. +func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{userplatformquota.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/userplatformquota_query.go b/backend/ent/userplatformquota_query.go new file mode 100644 index 00000000..c68fe904 --- /dev/null +++ b/backend/ent/userplatformquota_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities. +type UserPlatformQuotaQuery struct { + config + ctx *QueryContext + order []userplatformquota.OrderOption + inters []Interceptor + predicates []predicate.UserPlatformQuota + withUser *UserQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UserPlatformQuotaQuery builder. +func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first UserPlatformQuota entity from the query. +// Returns a *NotFoundError when no UserPlatformQuota was found. +func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{userplatformquota.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UserPlatformQuota ID from the query. +// Returns a *NotFoundError when no UserPlatformQuota ID was found. +func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{userplatformquota.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UserPlatformQuota entity is found. +// Returns a *NotFoundError when no UserPlatformQuota entities are found. +func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{userplatformquota.Label} + default: + return nil, &NotSingularError{userplatformquota.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query. +// Returns a *NotSingularError when more than one UserPlatformQuota ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{userplatformquota.Label} + default: + err = &NotSingularError{userplatformquota.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UserPlatformQuotaSlice. +func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]() + return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UserPlatformQuota IDs. +func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery { + if _q == nil { + return nil + } + return &UserPlatformQuotaQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]userplatformquota.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...), + withUser: _q.withUser.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.UserPlatformQuota.Query(). +// GroupBy(userplatformquota.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &UserPlatformQuotaGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = userplatformquota.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.UserPlatformQuota.Query(). +// Select(userplatformquota.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q} + sbuild.label = userplatformquota.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations. +func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !userplatformquota.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) { + var ( + nodes = []*UserPlatformQuota{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withUser != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UserPlatformQuota).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UserPlatformQuota{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UserPlatformQuota) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID) + for i := range fields { + if fields[i] != userplatformquota.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(userplatformquota.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(userplatformquota.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = userplatformquota.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities. +type UserPlatformQuotaGroupBy struct { + selector + build *UserPlatformQuotaQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities. +type UserPlatformQuotaSelect struct { + *UserPlatformQuotaQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v) +} + +func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/userplatformquota_update.go b/backend/ent/userplatformquota_update.go new file mode 100644 index 00000000..4e924cc8 --- /dev/null +++ b/backend/ent/userplatformquota_update.go @@ -0,0 +1,985 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities. +type UserPlatformQuotaUpdate struct { + config + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// Where appends a list predicates to the UserPlatformQuotaUpdate builder. +func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate { + _u.mutation.ClearDeletedAt() + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetPlatform sets the "platform" field. +func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate { + _u.mutation.SetPlatform(v) + return _u +} + +// SetNillablePlatform sets the "platform" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetPlatform(*v) + } + return _u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetDailyLimitUsd() + _u.mutation.SetDailyLimitUsd(v) + return _u +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyLimitUsd(*v) + } + return _u +} + +// AddDailyLimitUsd adds value to the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddDailyLimitUsd(v) + return _u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearDailyLimitUsd() + return _u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetWeeklyLimitUsd() + _u.mutation.SetWeeklyLimitUsd(v) + return _u +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyLimitUsd(*v) + } + return _u +} + +// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddWeeklyLimitUsd(v) + return _u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearWeeklyLimitUsd() + return _u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetMonthlyLimitUsd() + _u.mutation.SetMonthlyLimitUsd(v) + return _u +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyLimitUsd(*v) + } + return _u +} + +// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddMonthlyLimitUsd(v) + return _u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearMonthlyLimitUsd() + return _u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetDailyUsageUsd() + _u.mutation.SetDailyUsageUsd(v) + return _u +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyUsageUsd(*v) + } + return _u +} + +// AddDailyUsageUsd adds value to the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddDailyUsageUsd(v) + return _u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetWeeklyUsageUsd() + _u.mutation.SetWeeklyUsageUsd(v) + return _u +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyUsageUsd(*v) + } + return _u +} + +// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddWeeklyUsageUsd(v) + return _u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetMonthlyUsageUsd() + _u.mutation.SetMonthlyUsageUsd(v) + return _u +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyUsageUsd(*v) + } + return _u +} + +// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddMonthlyUsageUsd(v) + return _u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetDailyWindowStart(v) + return _u +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyWindowStart(*v) + } + return _u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearDailyWindowStart() + return _u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetWeeklyWindowStart(v) + return _u +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyWindowStart(*v) + } + return _u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearWeeklyWindowStart() + return _u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetMonthlyWindowStart(v) + return _u +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyWindowStart(*v) + } + return _u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearMonthlyWindowStart() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate { + return _u.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate { + _u.mutation.ClearUser() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) { + if err := _u.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UserPlatformQuotaUpdate) defaults() error { + if _, ok := _u.mutation.UpdatedAt(); !ok { + if userplatformquota.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserPlatformQuotaUpdate) check() error { + if v, ok := _u.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`) + } + return nil +} + +func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime) + } + if value, ok := _u.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + } + if value, ok := _u.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.DailyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.WeeklyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.MonthlyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + } + if _u.mutation.DailyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + } + if _u.mutation.WeeklyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + } + if _u.mutation.MonthlyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{userplatformquota.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity. +type UserPlatformQuotaUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDeletedAt() + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetPlatform sets the "platform" field. +func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne { + _u.mutation.SetPlatform(v) + return _u +} + +// SetNillablePlatform sets the "platform" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetPlatform(*v) + } + return _u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetDailyLimitUsd() + _u.mutation.SetDailyLimitUsd(v) + return _u +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyLimitUsd(*v) + } + return _u +} + +// AddDailyLimitUsd adds value to the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddDailyLimitUsd(v) + return _u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDailyLimitUsd() + return _u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetWeeklyLimitUsd() + _u.mutation.SetWeeklyLimitUsd(v) + return _u +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyLimitUsd(*v) + } + return _u +} + +// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddWeeklyLimitUsd(v) + return _u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearWeeklyLimitUsd() + return _u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetMonthlyLimitUsd() + _u.mutation.SetMonthlyLimitUsd(v) + return _u +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyLimitUsd(*v) + } + return _u +} + +// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddMonthlyLimitUsd(v) + return _u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearMonthlyLimitUsd() + return _u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetDailyUsageUsd() + _u.mutation.SetDailyUsageUsd(v) + return _u +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyUsageUsd(*v) + } + return _u +} + +// AddDailyUsageUsd adds value to the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddDailyUsageUsd(v) + return _u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetWeeklyUsageUsd() + _u.mutation.SetWeeklyUsageUsd(v) + return _u +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyUsageUsd(*v) + } + return _u +} + +// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddWeeklyUsageUsd(v) + return _u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetMonthlyUsageUsd() + _u.mutation.SetMonthlyUsageUsd(v) + return _u +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyUsageUsd(*v) + } + return _u +} + +// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddMonthlyUsageUsd(v) + return _u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetDailyWindowStart(v) + return _u +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyWindowStart(*v) + } + return _u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDailyWindowStart() + return _u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetWeeklyWindowStart(v) + return _u +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyWindowStart(*v) + } + return _u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearWeeklyWindowStart() + return _u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetMonthlyWindowStart(v) + return _u +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyWindowStart(*v) + } + return _u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearMonthlyWindowStart() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne { + return _u.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// Where appends a list predicates to the UserPlatformQuotaUpdate builder. +func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UserPlatformQuota entity. +func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) { + if err := _u.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UserPlatformQuotaUpdateOne) defaults() error { + if _, ok := _u.mutation.UpdatedAt(); !ok { + if userplatformquota.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserPlatformQuotaUpdateOne) check() error { + if v, ok := _u.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`) + } + return nil +} + +func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID) + for _, f := range fields { + if !userplatformquota.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != userplatformquota.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime) + } + if value, ok := _u.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + } + if value, ok := _u.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.DailyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.WeeklyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.MonthlyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + } + if _u.mutation.DailyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + } + if _u.mutation.WeeklyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + } + if _u.mutation.MonthlyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &UserPlatformQuota{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{userplatformquota.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/go.mod b/backend/go.mod index d8a9c437..c3101165 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -6,12 +6,14 @@ require ( connectrpc.com/connect v1.19.2 entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/alicebob/miniredis/v2 v2.38.0 github.com/alitto/pond/v2 v2.6.2 github.com/andybalholm/brotli v1.2.0 github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 + github.com/aws/smithy-go v1.24.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 @@ -74,7 +76,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect @@ -160,6 +161,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zclconf/go-cty v1.14.4 // indirect github.com/zclconf/go-cty-yaml v1.1.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 04a9d449..309d95c3 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -18,6 +18,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw= github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw= +github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw= +github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= @@ -362,6 +364,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index ee088d63..3d9ff81d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -651,6 +651,12 @@ type ProxyProbeConfig struct { type BillingConfig struct { CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` + // UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口。 + // 消费点: + // - billing_cache_service.cacheWriteWorker 异步累加 + // - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载 + // 读写两端必须共用同一 TTL,避免缓存生命周期不一致导致 quota 计数漂移。 + UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"` } type CircuitBreakerConfig struct { @@ -688,6 +694,9 @@ type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 // 注意:这不影响流式数据传输,只控制等待响应头的时间 ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` + // OpenAIResponseHeaderTimeout: OpenAI/Codex 上游等待响应头的超时时间(秒),0表示无超时 + // OpenAI/Codex 请求可能在上游排队较久;默认不使用通用响应头超时截断。 + OpenAIResponseHeaderTimeout int `mapstructure:"openai_response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 @@ -717,6 +726,8 @@ type GatewayConfig struct { OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` // NodeTLSProxy: Node.js TLS 代理配置 NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"` + // OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1) + OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"` // ImageConcurrency: 图片生成独立并发限制配置(默认关闭) ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"` @@ -815,6 +826,21 @@ type GatewayConfig struct { ContextCompression ContextCompressionConfig `mapstructure:"context_compression"` } +// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。 +// 默认启用 HTTP/2;在部分代理不兼容时按策略回退 HTTP/1.1。 +type GatewayOpenAIHTTP2Config struct { + // Enabled: 是否启用 OpenAI HTTP/2 优先策略 + Enabled bool `mapstructure:"enabled"` + // AllowProxyFallbackToHTTP1: HTTP/HTTPS 代理出现明确 H2 兼容错误时,临时回退 HTTP/1.1 + AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"` + // FallbackErrorThreshold: 回退窗口内累计多少次兼容错误后触发回退 + FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"` + // FallbackWindowSeconds: 统计兼容错误的时间窗口(秒) + FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"` + // FallbackTTLSeconds: 触发后回退 HTTP/1.1 的持续时间(秒) + FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"` +} + // UserMessageQueueConfig 用户消息串行队列配置 // 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 type UserMessageQueueConfig struct { @@ -1647,6 +1673,7 @@ func setDefaults() { viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30) viper.SetDefault("billing.circuit_breaker.half_open_requests", 3) + viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400) // Turnstile viper.SetDefault("turnstile.required", false) @@ -1847,6 +1874,7 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.openai_response_header_timeout", 0) viper.SetDefault("gateway.log_upstream_error_body", true) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.inject_beta_for_apikey", false) @@ -1902,6 +1930,12 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) + // OpenAI HTTP upstream protocol strategy + viper.SetDefault("gateway.openai_http2.enabled", true) + viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true) + viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2) + viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60) + viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600) viper.SetDefault("gateway.image_concurrency.enabled", false) viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0) viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject) @@ -2523,6 +2557,12 @@ func (c *Config) Validate() error { if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") } + if c.Gateway.ResponseHeaderTimeout < 0 { + return fmt.Errorf("gateway.response_header_timeout must be non-negative") + } + if c.Gateway.OpenAIResponseHeaderTimeout < 0 { + return fmt.Errorf("gateway.openai_response_header_timeout must be non-negative") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -2697,6 +2737,15 @@ func (c *Config) Validate() error { if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") } + if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative") + } if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 99fec46c..1eae5ed9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -163,6 +163,41 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { } } +func TestLoadDefaultOpenAIHTTP2Enabled(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + require.NoError(t, err) + require.True(t, cfg.Gateway.OpenAIHTTP2.Enabled) + require.True(t, cfg.Gateway.OpenAIHTTP2.AllowProxyFallbackToHTTP1) +} + +func TestLoadOpenAIHTTP2DisabledFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_HTTP2_ENABLED", "false") + + cfg, err := Load() + require.NoError(t, err) + require.False(t, cfg.Gateway.OpenAIHTTP2.Enabled) +} + +func TestLoadDefaultOpenAIResponseHeaderTimeoutUnlimited(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + require.NoError(t, err) + require.Equal(t, 0, cfg.Gateway.OpenAIResponseHeaderTimeout) +} + +func TestLoadOpenAIResponseHeaderTimeoutFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT", "1800") + + cfg, err := Load() + require.NoError(t, err) + require.Equal(t, 1800, cfg.Gateway.OpenAIResponseHeaderTimeout) +} + func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) { resetViperWithJWTSecret(t) t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0") @@ -1220,6 +1255,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 }, wantErr: "gateway.max_body_size", }, + { + name: "gateway response header timeout", + mutate: func(c *Config) { c.Gateway.ResponseHeaderTimeout = -1 }, + wantErr: "gateway.response_header_timeout", + }, + { + name: "gateway openai response header timeout", + mutate: func(c *Config) { c.Gateway.OpenAIResponseHeaderTimeout = -1 }, + wantErr: "gateway.openai_response_header_timeout", + }, { name: "gateway max idle conns", mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 }, @@ -1275,6 +1320,21 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 }, wantErr: "gateway.openai_ws.apikey_max_conns_factor", }, + { + name: "gateway openai http2 fallback threshold", + mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackErrorThreshold = -1 }, + wantErr: "gateway.openai_http2.fallback_error_threshold", + }, + { + name: "gateway openai http2 fallback window", + mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackWindowSeconds = -1 }, + wantErr: "gateway.openai_http2.fallback_window_seconds", + }, + { + name: "gateway openai http2 fallback ttl", + mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackTTLSeconds = -1 }, + wantErr: "gateway.openai_http2.fallback_ttl_seconds", + }, { name: "gateway stream data interval range", mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, diff --git a/backend/internal/domain/models_list_config.go b/backend/internal/domain/models_list_config.go new file mode 100644 index 00000000..3f050585 --- /dev/null +++ b/backend/internal/domain/models_list_config.go @@ -0,0 +1,7 @@ +package domain + +// GroupModelsListConfig controls the optional custom /v1/models response list. +type GroupModelsListConfig struct { + Enabled bool `json:"enabled"` + Models []string `json:"models,omitempty"` +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 267206e6..6803badb 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -982,6 +982,100 @@ func (h *AccountHandler) Refresh(c *gin.Context) { response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) } +// ApplyOAuthCredentialsRequest is the payload for persisting re-authorized OAuth credentials. +type ApplyOAuthCredentialsRequest struct { + Type string `json:"type" binding:"required,oneof=oauth setup-token"` + Credentials map[string]any `json:"credentials" binding:"required"` + Extra map[string]any `json:"extra"` +} + +// ApplyOAuthCredentials 将"重新授权"得到的新凭据原子落库。 +// POST /api/v1/admin/accounts/:id/apply-oauth-credentials +// +// 与通用 PUT /:id (Update) 接口的关键区别: +// - 仅接收 type / credentials / extra 三个字段(不接受 concurrency / rpm / quota_* 等可能误传的字段) +// - Extra 走 UpdateAccountExtra(JSONB key 级合并),**绝不**全量覆盖; +// 避免 base_rpm / window_cost_limit / max_sessions / quota_* / privacy_mode +// 等持久化配置在重新授权后丢失 +// - 内置 ClearError + InvalidateToken,避免前端额外两次调用, +// 并修复旧路径未失效 token 缓存导致重新授权后立即 401 的隐性 bug +// +// 与 /refresh 的区别:/refresh 用现有 refresh_token 换 access_token(无用户交互), +// 本接口承接前端完成完整 OAuth 流程后的落库步骤。 +func (h *AccountHandler) ApplyOAuthCredentials(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + var req ApplyOAuthCredentialsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + ctx := c.Request.Context() + + // 预检查账号存在 + OAuth 类型(与 Refresh handler 语义一致,提供更友好的错误信息)。 + existing, err := h.adminService.GetAccount(ctx, accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + if !existing.IsOAuth() { + response.ErrorFrom(c, infraerrors.BadRequest("NOT_OAUTH", "cannot apply oauth credentials to non-OAuth account")) + return + } + + updatedAccount, err := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{ + Type: req.Type, + Credentials: req.Credentials, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 增量合并 Extra(JSONB key 级 merge,绝不覆盖 base_rpm / window_cost_limit / + // max_sessions / quota_* / privacy_mode 等持久化键)。 + // best-effort:失败仅记日志;下方 ClearAccountError 会从 DB 重新读取最新 account, + // 因此响应里的 extra 始终以 DB 为准——这里不需要手动维护内存快照。 + if len(req.Extra) > 0 { + if extraErr := h.adminService.UpdateAccountExtra(ctx, accountID, req.Extra); extraErr != nil { + extraKeys := make([]string, 0, len(req.Extra)) + for k := range req.Extra { + extraKeys = append(extraKeys, k) + } + slog.Error("apply_oauth_credentials.update_extra_failed", + "account_id", accountID, + "extra_keys", extraKeys, + "err", extraErr, + ) + } + } + + if cleared, clearErr := h.adminService.ClearAccountError(ctx, accountID); clearErr != nil { + slog.Warn("apply_oauth_credentials.clear_error_failed", + "account_id", accountID, + "err", clearErr, + ) + } else if cleared != nil { + updatedAccount = cleared + } + + if h.tokenCacheInvalidator != nil && updatedAccount.IsOAuth() { + if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil { + slog.Warn("apply_oauth_credentials.invalidate_token_failed", + "account_id", accountID, + "err", invalidateErr, + ) + } + } + + response.Success(c, h.buildAccountResponseWithRuntime(ctx, updatedAccount)) +} + // GetStats handles getting account statistics // GET /api/v1/admin/accounts/:id/stats func (h *AccountHandler) GetStats(c *gin.Context) { diff --git a/backend/internal/handler/admin/account_handler_list_test.go b/backend/internal/handler/admin/account_handler_list_test.go new file mode 100644 index 00000000..4d628365 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_list_test.go @@ -0,0 +1,52 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountListRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.GET("/api/v1/admin/accounts", handler.List) + return router, adminSvc +} + +func TestAccountHandlerListIncludesCreatedAt(t *testing.T) { + router, adminSvc := setupAccountListRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts?page=1&page_size=20&sort_by=created_at&sort_order=desc", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "created_at", adminSvc.lastListAccounts.sortBy) + + var payload struct { + Data struct { + Items []struct { + ID int64 `json:"id"` + CreatedAt string `json:"created_at"` + } `json:"items"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + require.Len(t, payload.Data.Items, 1) + + createdAt := payload.Data.Items[0].CreatedAt + require.NotEmpty(t, createdAt) + require.True(t, strings.HasSuffix(createdAt, "Z"), "created_at should be serialized as UTC") + parsed, err := time.Parse(time.RFC3339Nano, createdAt) + require.NoError(t, err) + _, offset := parsed.Zone() + require.Equal(t, 0, offset) +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index ddeaab02..bffddc8a 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router := gin.New() adminSvc := newStubAdminService() - userHandler := NewUserHandler(adminSvc, nil) + userHandler := NewUserHandler(adminSvc, nil, nil, nil) groupHandler := NewGroupHandler(adminSvc, nil, nil) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc, nil) @@ -33,6 +33,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/groups", groupHandler.List) router.GET("/api/v1/admin/groups/all", groupHandler.GetAll) + router.GET("/api/v1/admin/groups/:id/models-list-candidates", groupHandler.GetModelsListCandidates) router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID) router.POST("/api/v1/admin/groups", groupHandler.Create) router.PUT("/api/v1/admin/groups/:id", groupHandler.Update) @@ -177,6 +178,12 @@ func TestGroupHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/0/models-list-candidates?platform=openai", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "gpt-5.5") + body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"}) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body)) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 2fef94f1..fd0ec459 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -24,6 +24,7 @@ type stubAdminService struct { updatedProxyIDs []int64 updatedProxies []*service.UpdateProxyInput testedProxyIDs []int64 + getUserErr error createAccountErr error updateAccountErr error bulkUpdateAccountErr error @@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi } func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) { + if s.getUserErr != nil { + return nil, s.getUserErr + } for i := range s.users { if s.users[i].ID == id { return &s.users[i], nil @@ -261,6 +265,13 @@ func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Gro return &group, nil } +func (s *stubAdminService) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) { + if platform == service.PlatformOpenAI { + return []string{"gpt-5.5", "gpt-5.4"}, nil + } + return []string{"claude-sonnet-4-6"}, nil +} + func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) { group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive} return &group, nil @@ -345,6 +356,10 @@ func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *s return &account, nil } +func (s *stubAdminService) UpdateAccountExtra(ctx context.Context, id int64, updates map[string]any) error { + return nil +} + func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error { return nil } diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go index defcd29d..e11ea6eb 100644 --- a/backend/internal/handler/admin/content_moderation_handler.go +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -34,6 +34,7 @@ type contentModerationConfigRequest struct { AllGroups *bool `json:"all_groups"` GroupIDs *[]int64 `json:"group_ids"` RecordNonHits *bool `json:"record_non_hits"` + Thresholds *map[string]float64 `json:"thresholds"` WorkerCount *int `json:"worker_count"` QueueSize *int `json:"queue_size"` BlockStatus *int `json:"block_status"` @@ -94,6 +95,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { AllGroups: req.AllGroups, GroupIDs: req.GroupIDs, RecordNonHits: req.RecordNonHits, + Thresholds: req.Thresholds, WorkerCount: req.WorkerCount, QueueSize: req.QueueSize, BlockStatus: req.BlockStatus, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 17b9555f..de2dc4d0 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -113,6 +113,7 @@ type CreateGroupRequest struct { RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig service.GroupModelsListConfig `json:"models_list_config"` // 分组 RPM 上限(0 = 不限制) RPMLimit int `json:"rpm_limit"` // 从指定分组复制账号(创建后自动绑定) @@ -153,6 +154,7 @@ type UpdateGroupRequest struct { RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig *service.GroupModelsListConfig `json:"models_list_config"` // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动 RPMLimit *int `json:"rpm_limit"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) @@ -238,6 +240,28 @@ func (h *GroupHandler) GetByID(c *gin.Context) { response.Success(c, dto.GroupFromServiceAdmin(group)) } +// GetModelsListCandidates handles getting candidate model IDs for custom /v1/models list. +// GET /api/v1/admin/groups/:id/models-list-candidates +func (h *GroupHandler) GetModelsListCandidates(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || groupID < 0 { + response.BadRequest(c, "Invalid group ID") + return + } + + models, err := h.adminService.GetGroupModelsListCandidates( + c.Request.Context(), + groupID, + c.Query("platform"), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"models": models}) +} + // Create handles creating a new group // POST /api/v1/admin/groups func (h *GroupHandler) Create(c *gin.Context) { @@ -275,6 +299,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + ModelsListConfig: req.ModelsListConfig, RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) @@ -330,6 +355,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + ModelsListConfig: req.ModelsListConfig, RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 36a60c23..3c7fe581 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) } + // Default platform quotas(JSON map) + if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil { + slog.Error("default_platform_quotas_get_failed", "error", err) + } else { + payload.DefaultPlatformQuotas = platformQuotas + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } @@ -637,6 +644,18 @@ type UpdateSettingsRequest struct { // OpenAI fast/flex policy (optional, only updated when provided) OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` + + // 系统全局 platform quota 默认值(整体替换语义:nil = 不修改,non-nil = 整体覆盖)。 + DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"` + + // auth-source 层 platform quota 覆盖(override 语义:nil = 不修改,non-nil = 整体覆盖该 source 的 quota 配置)。 + AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"` + AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"` + AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"` + AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"` + AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"` + AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"` + AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"` } // UpdateSettings 更新系统设置 @@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ + // 系统全局 platform quota 默认值(整体替换语义) + DefaultPlatformQuotas: req.DefaultPlatformQuotas, + RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, @@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { }(), } + // req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值); + // non-nil(含 empty map)表示整体覆盖:empty map = 清空该 source 的所有 quota 配置。 authSourceDefaults := &service.AuthSourceDefaultSettings{ Email: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), @@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas), }, LinuxDo: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), @@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas), }, OIDC: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), @@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas), }, WeChat: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), @@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas), }, GitHub: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance), @@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas), }, Google: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance), @@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas), }, DingTalk: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance), @@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas), }, ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), } @@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } else if fastPolicy != nil { payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) } + + // Default platform quotas(JSON map)—— 与 GetSettings 一致,避免保存后响应缺失该字段 + if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil { + slog.Error("default_platform_quotas_get_failed", "error", err) + } else { + payload.DefaultPlatformQuotas = platformQuotas + } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } @@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.RiskControlEnabled != after.RiskControlEnabled { changed = append(changed, "risk_control_enabled") } + // Default platform quotas(JSON map,整体比较) + if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) { + changed = append(changed, service.SettingKeyDefaultPlatformQuotas) + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } @@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind { changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind") } + // Platform quotas diff:整体替换语义,发单个 JSON key。 + if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) { + changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name)) + } } if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup { changed = append(changed, "force_email_on_third_party_signup") @@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, return result } +// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义: +// nil = 请求未包含该字段(保留 fallback),non-nil(含 empty map)= 整体覆盖。 +// 注意:JSON null 与字段省略等价——两者均反序列化为 nil map,因此都保留旧值; +// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。 +func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting { + if value == nil { + return fallback + } + return value +} + func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { data := make(map[string]any) raw, err := json.Marshal(settings) @@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind + data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas + data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas + data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas + data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas + data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas + data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas + data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup return data @@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) } return placeholders } + +// equalNullableFloat compares two *float64 values treating nil as a distinct case. +func equalNullableFloat(a, b *float64) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting. +func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 { + if s == nil { + return nil + } + switch win { + case "daily": + return s.DailyLimitUSD + case "weekly": + return s.WeeklyLimitUSD + case "monthly": + return s.MonthlyLimitUSD + } + return nil +} + +// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots. +func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool { + for _, platform := range service.AllowedQuotaPlatforms { + b := before[platform] + a := after[platform] + if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) { + return false + } + if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) { + return false + } + if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) { + return false + } + } + return true +} diff --git a/backend/internal/handler/admin/setting_handler_platform_quota_test.go b/backend/internal/handler/admin/setting_handler_platform_quota_test.go new file mode 100644 index 00000000..273b3442 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_platform_quota_test.go @@ -0,0 +1,188 @@ +//go:build unit + +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) { + five := 5.0 + ten := 10.0 + before := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + after := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + } + + changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{}) + found := false + for _, key := range changed { + if key == service.SettingKeyDefaultPlatformQuotas { + found = true + break + } + } + if !found { + t.Errorf("expected change detection for default platform quotas, got %v", changed) + } +} + +func TestDiffSettings_NoChangeWhenEqual(t *testing.T) { + five := 5.0 + before := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + after := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + + changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{}) + for _, key := range changed { + if key == service.SettingKeyDefaultPlatformQuotas { + t.Error("equal values should not be detected as changed") + } + } +} + +func TestEqualNullableFloat(t *testing.T) { + five := 5.0 + five2 := 5.0 + ten := 10.0 + cases := []struct { + a, b *float64 + want bool + }{ + {nil, nil, true}, + {&five, nil, false}, + {nil, &five, false}, + {&five, &five2, true}, + {&five, &ten, false}, + } + for _, c := range cases { + if got := equalNullableFloat(c.a, c.b); got != c.want { + t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want) + } + } +} + +func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) { + five := 5.0 + ten := 10.0 + before := map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + } + after := map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + } + if equalPlatformQuotaSettings(before, after) { + t.Error("expected unequal") + } +} + +func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) { + five := 5.0 + ten := 10.0 + before := &service.AuthSourceDefaultSettings{ + LinuxDo: service.ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + }, + } + after := &service.AuthSourceDefaultSettings{ + LinuxDo: service.ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + }, + } + + changed := appendAuthSourceDefaultChanges([]string{}, before, after) + // 改动 B5:整体替换语义,审计 log 发单个 JSON key,而非展开 84 个扁平 key。 + key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo") + found := false + for _, k := range changed { + if k == key { + found = true + break + } + } + if !found { + t.Errorf("expected %q in changed, got %v", key, changed) + } +} + +// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复: +// PUT 发 auth_source_default_email_platform_quotas,GET 能读回相同值(端到端往返)。 +func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) + + // PUT:发 email platform quota(openai monthly=20) + putBody := map[string]any{ + "auth_source_default_email_platform_quotas": map[string]any{ + "openai": map[string]any{ + "monthly": 20, + }, + }, + } + rawBody, err := json.Marshal(putBody) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + handler.UpdateSettings(c) + require.Equal(t, http.StatusOK, rec.Code) + + // 验证 DB 中写入了 JSON key + jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email") + require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB") + + // GET:验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20 + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + handler.GetSettings(c2) + require.Equal(t, http.StatusOK, rec2.Code) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + + emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any) + require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map") + openaiPQ, ok := emailPQ["openai"].(map[string]any) + require.True(t, ok, "expected openai entry in email platform quotas") + monthly, ok := openaiPQ["monthly"].(float64) + require.True(t, ok, "expected monthly to be float64") + require.Equal(t, float64(20), monthly, "expected openai monthly=20") +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index db35472e..32a21692 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -2,10 +2,15 @@ package admin import ( "context" + "errors" + "log/slog" + "math" "strconv" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -20,15 +25,24 @@ type UserWithConcurrency struct { // UserHandler handles admin user management type UserHandler struct { - adminService service.AdminService - concurrencyService *service.ConcurrencyService + adminService service.AdminService + concurrencyService *service.ConcurrencyService + userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view + billingCache service.BillingCache // T17/T18 缓存失效(PUT/POST 路径) } // NewUserHandler creates a new admin user handler -func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler { +func NewUserHandler( + adminService service.AdminService, + concurrencyService *service.ConcurrencyService, + userPlatformQuotaRepo service.UserPlatformQuotaRepository, + billingCache service.BillingCache, +) *UserHandler { return &UserHandler{ - adminService: adminService, - concurrencyService: concurrencyService, + adminService: adminService, + concurrencyService: concurrencyService, + userPlatformQuotaRepo: userPlatformQuotaRepo, + billingCache: billingCache, } } @@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) { } response.Success(c, gin.H{"affected": affected}) } + +// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas +// admin 视角:D14 lazy 归零 + 暴露 *_window_start 调试字段 +func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) { + idStr := c.Param("id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid user id") + return + } + if h.userPlatformQuotaRepo == nil { + response.Success(c, map[string]any{"platform_quotas": []any{}}) + return + } + // 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。 + if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil { + response.ErrorFrom(c, err) + return + } + records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + out := make([]map[string]any, 0, len(records)) + for _, r := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + +// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas. +type UpdateUserPlatformQuotasRequest struct { + Quotas []PlatformQuotaInput `json:"quotas" binding:"required"` +} + +// PlatformQuotaInput 单平台限额输入;limit 字段为 nil 表示不限制。 +type PlatformQuotaInput struct { + Platform string `json:"platform" binding:"required"` + DailyLimitUSD *float64 `json:"daily_limit_usd"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` +} + +// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。 + +// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas +// 全量替换该用户所有平台限额。 +func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) { + if h.userPlatformQuotaRepo == nil { + response.Error(c, 503, "platform quota service not available") + return + } + + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req UpdateUserPlatformQuotasRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.Quotas) > 4 { + response.BadRequest(c, "quotas length must be <= 4") + return + } + seen := make(map[string]struct{}, len(req.Quotas)) + for _, q := range req.Quotas { + if !service.IsAllowedQuotaPlatform(q.Platform) { + response.BadRequest(c, "invalid platform: "+q.Platform) + return + } + if _, dup := seen[q.Platform]; dup { + response.BadRequest(c, "duplicate platform: "+q.Platform) + return + } + seen[q.Platform] = struct{}{} + // daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义: + // nil / not set → 无限额(完全放行) + // 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立) + // > 0 → USD 限额上限 + // 拦截 NaN / ±Inf:客户端可发送超大数(如 1e308 × 2)使 JSON 反序列化得到 +Inf, + // 进入 DB 后 cache check 中 usage >= limit 永不成立,limit 等同失效。 + for _, f := range []struct { + name string + val *float64 + }{ + {"daily_limit_usd", q.DailyLimitUSD}, + {"weekly_limit_usd", q.WeeklyLimitUSD}, + {"monthly_limit_usd", q.MonthlyLimitUSD}, + } { + if f.val == nil { + continue + } + v := *f.val + if v < 0 { + response.BadRequest(c, f.name+" must be >= 0") + return + } + if math.IsNaN(v) || math.IsInf(v, 0) { + response.BadRequest(c, f.name+" must be a finite number") + return + } + } + } + + records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas)) + for _, q := range req.Quotas { + records = append(records, service.UserPlatformQuotaRecord{ + UserID: userID, + Platform: q.Platform, + DailyLimitUSD: q.DailyLimitUSD, + WeeklyLimitUSD: q.WeeklyLimitUSD, + MonthlyLimitUSD: q.MonthlyLimitUSD, + }) + } + + ctx := c.Request.Context() + // 校验用户是否存在,避免 FK 违反导致 500;用户不存在时返回 404。 + if _, err := h.adminService.GetUser(ctx, userID); err != nil { + response.ErrorFrom(c, err) + return + } + // 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。 + // ListByUser 失败不阻断主操作(best-effort),仅记录降级 warn。 + beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if beforeErr != nil { + slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr) + } + if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil { + response.ErrorFrom(c, err) + return + } + + beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords)) + for _, r := range beforeRecords { + beforeByPlatform[r.Platform] = r + } + afterPlatforms := make(map[string]struct{}, len(records)) + for _, r := range records { + afterPlatforms[r.Platform] = struct{}{} + } + changes := make([]map[string]any, 0, len(records)) + for _, r := range records { + entry := map[string]any{ + "platform": r.Platform, + "daily_limit_usd": r.DailyLimitUSD, + "weekly_limit_usd": r.WeeklyLimitUSD, + "monthly_limit_usd": r.MonthlyLimitUSD, + } + if prev, ok := beforeByPlatform[r.Platform]; ok { + entry["before_daily_limit_usd"] = prev.DailyLimitUSD + entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD + entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD + } + changes = append(changes, entry) + } + // 补 removed 条目:before 存在但 after 缺失 = 该平台被软删除。 + // 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。 + for _, prev := range beforeRecords { + if _, kept := afterPlatforms[prev.Platform]; kept { + continue + } + changes = append(changes, map[string]any{ + "platform": prev.Platform, + "removed": true, + "before_daily_limit_usd": prev.DailyLimitUSD, + "before_weekly_limit_usd": prev.WeeklyLimitUSD, + "before_monthly_limit_usd": prev.MonthlyLimitUSD, + }) + } + // before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段; + // false 时所有 entry 都会缺失 before_*_limit_usd,仅有 after 视图。 + slog.Info("admin.quota_updated", + "actor_admin_id", getAdminIDFromContext(c), + "target_user_id", userID, + "platform_count", len(records), + "before_snapshot_available", beforeErr == nil, + "changes", changes) + + // 失效 cache:对全部允许的 platform 统一 invalidate。 + // Trade-off:精确失效(仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser, + // 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素, + // 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。 + if h.billingCache != nil { + for _, p := range service.AllowedQuotaPlatforms { + if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil { + slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err) + } + } + } + + // 返回最新状态 + now := time.Now().UTC() + records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]map[string]any, 0, len(records2)) + for i := range records2 { + out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + +// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset. +type ResetUserPlatformQuotaWindowRequest struct { + Platform string `json:"platform" binding:"required"` + Window string `json:"window" binding:"required"` +} + +var allowedWindowsForQuotaReset = map[string]struct{}{ + "daily": {}, + "weekly": {}, + "monthly": {}, +} + +// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset +// 立即归零指定 (platform, window) 的用量并更新 window_start。 +func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) { + if h.userPlatformQuotaRepo == nil { + response.Error(c, 503, "platform quota service not available") + return + } + + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req ResetUserPlatformQuotaWindowRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !service.IsAllowedQuotaPlatform(req.Platform) { + response.BadRequest(c, "invalid platform: "+req.Platform) + return + } + if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok { + response.BadRequest(c, "invalid window: "+req.Window) + return + } + + ctx := c.Request.Context() + // 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。 + if _, err := h.adminService.GetUser(ctx, userID); err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil { + if errors.Is(err, service.ErrUserPlatformQuotaNotFound) { + response.NotFound(c, "user platform quota not found") + return + } + response.ErrorFrom(c, err) + return + } + + slog.Info("admin.quota_window_reset", + "actor_admin_id", getAdminIDFromContext(c), + "target_user_id", userID, + "platform", req.Platform, + "window", req.Window) + + if h.billingCache != nil { + if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil { + slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err) + } + } + + records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]map[string]any, 0, len(records)) + for i := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go index bfba2408..a9fdc48a 100644 --- a/backend/internal/handler/admin/user_handler_activity_test.go +++ b/backend/internal/handler/admin/user_handler_activity_test.go @@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) { UpdatedAt: lastLoginAt, }, } - handler := NewUserHandler(adminSvc, nil) + handler := NewUserHandler(adminSvc, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) { UpdatedAt: lastLoginAt, }, } - handler := NewUserHandler(adminSvc, nil) + handler := NewUserHandler(adminSvc, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) diff --git a/backend/internal/handler/admin/user_platform_quota_admin_test.go b/backend/internal/handler/admin/user_platform_quota_admin_test.go new file mode 100644 index 00000000..fe33d36c --- /dev/null +++ b/backend/internal/handler/admin/user_platform_quota_admin_test.go @@ -0,0 +1,301 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository,捕获 UpsertForUser 调用。 +type upsertCapturingQuotaRepo struct { + service.UserPlatformQuotaRepository + listRecords []service.UserPlatformQuotaRecord + listErr error + upsertCalls []upsertCall + upsertErr error + resetCalls []resetCall + resetErr error +} + +type upsertCall struct { + userID int64 + records []service.UserPlatformQuotaRecord +} +type resetCall struct { + userID int64 + platform string + window string + newStart time.Time +} + +func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return r.listRecords, r.listErr +} +func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + cloned := make([]service.UserPlatformQuotaRecord, len(records)) + copy(cloned, records) + r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned}) + return r.upsertErr +} +func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error { + r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart}) + return r.resetErr +} + +// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。 +type billingCacheStub struct { + service.BillingCache + deleteCalls []deleteCall + deleteErr error +} + +type deleteCall struct { + userID int64 + platform string +} + +func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error { + b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform}) + return b.deleteErr +} + +func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler { + return &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: newStubAdminService(), + } +} + +func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + c.Params = []gin.Param{{Key: "id", Value: "42"}} + return c, w +} + +func TestUpdateUserPlatformQuotas_Success(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + + body := `{"quotas":[ + {"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0}, + {"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if len(repo.upsertCalls) != 1 { + t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls)) + } + if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 { + t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0]) + } + // 缓存失效:请求中 2 个 platform + 软删除的 2 个 platform(gemini, antigravity)= 4 次 + if len(cache.deleteCalls) != 4 { + t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls) + } +} + +func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[ + {"platform":"anthropic","daily_limit_usd":1}, + {"platform":"anthropic","daily_limit_usd":2} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[ + {"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) { + repo := &upsertCapturingQuotaRepo{ + listRecords: []service.UserPlatformQuotaRecord{ + {UserID: 42, Platform: "anthropic"}, + }, + } + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if !strings.Contains(w.Body.String(), `"platform_quotas"`) { + t.Errorf("response should contain platform_quotas array: %s", w.Body.String()) + } +} + +// ───────── T4: Reset 测试 ───────── + +func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + c.Params = []gin.Param{{Key: "id", Value: "42"}} + return c, w +} + +func TestResetUserPlatformQuotaWindow_Success(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + body := `{"platform":"anthropic","window":"daily"}` + c, w := postReq(t, body) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if len(repo.resetCalls) != 1 { + t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls)) + } + if repo.resetCalls[0].userID != 42 || + repo.resetCalls[0].platform != "anthropic" || + repo.resetCalls[0].window != "daily" { + t.Errorf("unexpected reset call: %+v", repo.resetCalls[0]) + } + if len(cache.deleteCalls) != 1 || + cache.deleteCalls[0].userID != 42 || + cache.deleteCalls[0].platform != "anthropic" { + t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls) + } +} + +func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"unknown","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) { + // handler 检查 service.ErrUserPlatformQuotaNotFound(由 adapter 包装而来) + repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound} + h := buildTestHandler(repo, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) { + repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code < 500 { + t.Errorf("expected 5xx, got %d", w.Code) + } + // 返回 JSON 错误响应 + var body2 map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil { + t.Errorf("expected JSON error body, got: %s", w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + adminSvc := newStubAdminService() + adminSvc.getUserErr = service.ErrUserNotFound + h := &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: adminSvc, + } + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + adminSvc := newStubAdminService() + adminSvc.getUserErr = service.ErrUserNotFound + h := &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: adminSvc, + } + c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/admin/user_platform_quotas_handler_test.go b/backend/internal/handler/admin/user_platform_quotas_handler_test.go new file mode 100644 index 00000000..1927a6e8 --- /dev/null +++ b/backend/internal/handler/admin/user_platform_quotas_handler_test.go @@ -0,0 +1,124 @@ +//go:build unit + +package admin + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type fakeQuotaRepoForAdmin struct { + service.UserPlatformQuotaRepository + records []service.UserPlatformQuotaRecord + err error +} + +func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return f.records, f.err +} + +func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context { + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodGet, "/", nil) + c.Request = req + return c +} + +func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) { + start := time.Now().Add(-1 * time.Hour) + repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{ + UserID: 99, Platform: "anthropic", + DailyUsageUSD: 1.0, DailyWindowStart: &start, + }}} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "99"}} + h.GetUserPlatformQuotas(c) + + if w.Code != 200 { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"daily_window_start"`) { + t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String()) + } +} + +func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "abc"}} + h.GetUserPlatformQuotas(c) + if w.Code < 400 || w.Code >= 500 { + t.Errorf("invalid id should yield 4xx, got %d", w.Code) + } +} + +func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) { + repo := &fakeQuotaRepoForAdmin{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "99"}} + h.GetUserPlatformQuotas(c) + if w.Code != 200 { + t.Errorf("empty list should be 200, got %d", w.Code) + } + var body map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + data, ok := body["data"].(map[string]any) + if !ok { + t.Fatalf("response missing data object: %v", body) + } + quotas, ok := data["platform_quotas"].([]any) + if !ok { + t.Fatalf("data.platform_quotas missing or wrong type: %v", data) + } + if len(quotas) != 0 { + t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas) + } +} + +func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "1"}} + h.GetUserPlatformQuotas(c) + if w.Code != 200 { + t.Errorf("nil repo should return 200 empty, got %d", w.Code) + } +} + +// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404 +// (与 PUT / POST reset 端点行为一致;review fix:原实现返回空数组会让 admin 界面误判用户存在) +func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + repo := &fakeQuotaRepoForAdmin{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "999"}} + h.GetUserPlatformQuotas(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 4081b9e4..70fb160a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates ( nil, options.defaultSubAssigner, affiliateService, + nil, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go index f1c6d87d..922bf9c3 100644 --- a/backend/internal/handler/auth_session_revocation_test.go +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) handler := &AuthHandler{authService: authService} recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index b3c7786d..e291b7e7 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, nil, nil, nil, + nil, ) return &AuthHandler{ diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2c71be9d..51a11ea7 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -147,6 +147,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { MCPXMLInject: g.MCPXMLInject, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + ModelsListConfig: g.ModelsListConfig, SupportedModelScopes: g.SupportedModelScopes, AccountCount: g.AccountCount, ActiveAccountCount: g.ActiveAccountCount, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index fac60573..eecf98ac 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -3,6 +3,8 @@ package dto import ( "encoding/json" "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" ) // CustomMenuItem represents a user-configured custom menu entry. @@ -246,6 +248,9 @@ type SystemSettings struct { // OpenAI fast/flex policy OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` + + // 系统全局默认平台配额(key = platform,nil/缺省 = 不限制) + DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 31828375..b1841c62 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -138,6 +138,7 @@ type AdminGroup struct { // OpenAI Messages 调度配置(仅 openai 平台使用) DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index 16d97908..a2ba0f36 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -17,6 +17,7 @@ import ( const ( EndpointMessages = "/v1/messages" EndpointChatCompletions = "/v1/chat/completions" + EndpointEmbeddings = "/v1/embeddings" EndpointResponses = "/v1/responses" EndpointImagesGenerations = "/v1/images/generations" EndpointImagesEdits = "/v1/images/edits" @@ -42,6 +43,8 @@ const ( func NormalizeInboundEndpoint(path string) string { path = strings.TrimSpace(path) switch { + case strings.Contains(path, EndpointEmbeddings): + return EndpointEmbeddings case strings.Contains(path, EndpointChatCompletions): return EndpointChatCompletions case strings.Contains(path, EndpointMessages): @@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { switch platform { case service.PlatformOpenAI: - if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { + if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { return inbound } // OpenAI forwards everything to the Responses API. diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index 369c5fa7..42b6d6e7 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) { // Direct canonical paths. {"/v1/messages", EndpointMessages}, {"/v1/chat/completions", EndpointChatCompletions}, + {"/v1/embeddings", EndpointEmbeddings}, {"/v1/responses", EndpointResponses}, {"/v1/images/generations", EndpointImagesGenerations}, {"/v1/images/edits", EndpointImagesEdits}, @@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { {"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"}, {"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses}, {"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses}, + {"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings}, {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations}, {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits}, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 38dad596..ada1a499 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/http" "strconv" "strings" @@ -253,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 2. 【新增】Wait后二次检查余额/订阅 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -533,10 +534,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, @@ -825,6 +828,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Beta policy block: return 400 immediately, no failover var betaBlockedErr *service.BetaBlockedError if errors.As(err, &betaBlockedErr) { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied) h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message) return } @@ -855,7 +859,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil { status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { c.Header("Retry-After", strconv.Itoa(retryAfter)) @@ -960,10 +964,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, + QuotaPlatform: quotaPlatform, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, @@ -1015,22 +1021,14 @@ func (h *GatewayHandler) Models(c *gin.Context) { // Get available models from account configurations for the selected group platform. availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform) + if apiKey != nil && apiKey.Group != nil && apiKey.Group.CustomModelsListEnabled() { + availableModels = filterModelsByCustomList(availableModels, defaultModelIDsForPlatform(platform), apiKey.Group.ModelsListConfig.Models) + writeCustomModelsList(c, platform, availableModels) + return + } if len(availableModels) > 0 { - // Build model list from whitelist - models := make([]claude.Model, 0, len(availableModels)) - for _, modelID := range availableModels { - models = append(models, claude.Model{ - ID: modelID, - Type: "model", - DisplayName: modelID, - CreatedAt: "2024-01-01T00:00:00Z", - }) - } - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": models, - }) + writeModelsList(c, availableModels) return } @@ -1057,6 +1055,134 @@ func (h *GatewayHandler) Models(c *gin.Context) { }) } +func writeModelsList(c *gin.Context, modelIDs []string) { + models := make([]claude.Model, 0, len(modelIDs)) + for _, modelID := range modelIDs { + models = append(models, claude.Model{ + ID: modelID, + Type: "model", + DisplayName: modelID, + CreatedAt: "2024-01-01T00:00:00Z", + }) + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) +} + +func writeCustomModelsList(c *gin.Context, platform string, modelIDs []string) { + if platform == service.PlatformOpenAI { + writeOpenAIModelsList(c, modelIDs) + return + } + writeModelsList(c, modelIDs) +} + +func writeOpenAIModelsList(c *gin.Context, modelIDs []string) { + defaultsByID := make(map[string]openai.Model, len(openai.DefaultModels)) + for _, model := range openai.DefaultModels { + defaultsByID[model.ID] = model + } + + models := make([]openai.Model, 0, len(modelIDs)) + for _, modelID := range modelIDs { + if model, ok := defaultsByID[modelID]; ok { + models = append(models, model) + continue + } + models = append(models, openai.Model{ + ID: modelID, + Object: "model", + Created: 1704067200, + OwnedBy: "openai", + Type: "model", + DisplayName: modelID, + }) + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) +} + +func filterModelsByCustomList(availableModels, fallbackModels, selectedModels []string) []string { + if len(selectedModels) == 0 { + return availableModels + } + source := availableModels + if len(source) == 0 { + source = fallbackModels + } + if len(source) == 0 { + return nil + } + + allowed := make([]string, 0, len(source)) + for _, model := range source { + model = strings.TrimSpace(model) + if model != "" { + allowed = append(allowed, model) + } + } + + seen := make(map[string]struct{}, len(selectedModels)) + filtered := make([]string, 0, len(selectedModels)) + for _, model := range selectedModels { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if !customModelsListAllowsModel(allowed, model) { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + filtered = append(filtered, model) + } + return filtered +} + +func customModelsListAllowsModel(availablePatterns []string, model string) bool { + for _, pattern := range availablePatterns { + if pattern == model { + return true + } + if strings.HasSuffix(pattern, "*") && strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) { + return true + } + } + return false +} + +func defaultModelIDsForPlatform(platform string) []string { + switch platform { + case service.PlatformOpenAI: + return openai.DefaultModelIDs() + case service.PlatformGemini: + ids := make([]string, 0, len(geminicli.DefaultModels)) + for _, model := range geminicli.DefaultModels { + ids = append(ids, model.ID) + } + return ids + case service.PlatformAntigravity: + models := antigravity.DefaultModels() + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids + default: + ids := make([]string, 0, len(claude.DefaultModels)) + for _, model := range claude.DefaultModels { + ids = append(ids, model.ID) + } + return ids + } +} + // AntigravityModels 返回 Antigravity 支持的全部模型 // GET /antigravity/models func (h *GatewayHandler) AntigravityModels(c *gin.Context) { @@ -1502,6 +1628,14 @@ func (h *GatewayHandler) sendFailoverKeepalivePing(c *gin.Context, streamStarted // handleStreamingAwareError handles errors that may occur after streaming has started func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { + // /v1/responses 的严格 SDK(Codex CLI)要求终止事件必须属于 + // response.completed/failed/incomplete/cancelled 集合。 + // Anthropic-backed Responses 路径同样会因为通用 error 帧被拒。 + if inboundIsResponses(c) { + if writeResponsesFailedSSE(c, errType, message) { + return + } + } // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { @@ -1520,10 +1654,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e } // ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +// Writer 已被写过时(ping 已 flush)走 streamStarted 分支, +// 让 handleStreamingAwareError 通过 SSE 发协议合规的终止事件, +// 否则下游收到的就是 silent EOF。 func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { - if c == nil || c.Writer == nil || c.Writer.Written() { + if c == nil || c.Writer == nil { return false } + if c.Writer.Written() { + streamStarted = true + } h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) return true } @@ -1650,7 +1790,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { c.Header("Retry-After", strconv.Itoa(retryAfter)) @@ -1898,6 +2038,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter c.JSON(http.StatusOK, response) } +// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算 +// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。 +func extractQuotaResetSeconds(err error) int { + const fallback = 60 + appErr := pkgerrors.FromError(err) + if appErr == nil { + return fallback + } + raw, ok := appErr.Metadata["window_resets_at"] + if !ok || raw == "" { + return fallback + } + resetAt, parseErr := time.Parse(time.RFC3339, raw) + if parseErr != nil { + logger.L().With( + zap.String("component", "handler.gateway.billing"), + zap.String("raw", raw), + zap.Error(parseErr), + ).Warn("quota.invalid_window_resets_at_format") + return fallback + } + secs := time.Until(resetAt).Seconds() + if secs <= 0 { + // reset 时间已过:cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避, + // 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。 + return fallback + } + return int(math.Ceil(secs)) +} + func billingErrorDetails(err error) (status int, code, message string, retryAfter int) { if errors.Is(err, service.ErrBillingServiceUnavailable) { msg := pkgerrors.Message(err) @@ -1925,6 +2095,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte retrySeconds := 60 - int(time.Now().Unix()%60) return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds } + if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) || + errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) || + errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) { + // 与 RPM 超限一致映射 429 + Retry-After,让 SDK 自动退避(而非 403 直接失败)。 + // 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。 + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err) + } msg := pkgerrors.Message(err) if msg == "" { logger.L().With( diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go index e8a88802..eaf48d1e 100644 --- a/backend/internal/handler/gateway_handler_billing_error_test.go +++ b/backend/internal/handler/gateway_handler_billing_error_test.go @@ -1,8 +1,10 @@ package handler import ( + "errors" "net/http" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" @@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) { require.Equal(t, "billing_error", code) require.NotEmpty(t, msg) } + +func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) { + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339), + }) + got := extractQuotaResetSeconds(err) + if got < 10 || got > 11 { + t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got) + } +} + +func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) { + if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 { + t.Errorf("T20: got %d, want 60 fallback", got) + } +} + +func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) { + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": "not-a-time", + }) + if got := extractQuotaResetSeconds(err); got != 60 { + t.Errorf("T21: got %d, want 60 fallback", got) + } +} + +func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) { + // 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s: + // 1 秒会导致客户端立即重试仍触发限额的退避循环; + // 60s 让客户端按常规节奏退避,cache/DB 自愈期间不会反复打抖。 + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339), + }) + if got := extractQuotaResetSeconds(err); got != 60 { + t.Errorf("T22: got %d, want 60 (fallback on past reset)", got) + } +} + +func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) { + // quota 超限映射 429 + Retry-After(RFC 6585 / 与 RPM 一致), + // 让 SDK(OpenAI 兼容客户端等)能按 Retry-After 自动退避。 + // 旧实现用 403 导致客户端不退避直接报错。 + // 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。 + cases := []struct { + name string + err error + }{ + {"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + {"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + {"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + status, code, _, retryAfter := billingErrorDetails(tc.err) + if status != http.StatusTooManyRequests { + t.Errorf("status = %d, want 429", status) + } + if code != "rate_limit_exceeded" { + t.Errorf("code = %q, want rate_limit_exceeded", code) + } + if retryAfter < 3599 || retryAfter > 3601 { + t.Errorf("retryAfter = %d, want ~3600", retryAfter) + } + }) + } +} diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 9a091fcd..acbdc261 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { } // 2. Re-check billing - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go index 4fce5ec1..fe9e2ebf 100644 --- a/backend/internal/handler/gateway_handler_error_fallback_test.go +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -33,7 +33,9 @@ func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testi assert.Equal(t, "Upstream request failed", errorObj["message"]) } -func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { +// Writer 已写后 ensureForwardErrorResponse 必须把错误以 SSE 形式追加, +// 而不是 silent EOF。非 /responses 路径走 legacy data:{"type":"error"} 分支。 +func TestGatewayEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -43,7 +45,27 @@ func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *tes h := &GatewayHandler{} wrote := h.ensureForwardErrorResponse(c, false) - require.False(t, wrote) + require.True(t, wrote) require.Equal(t, http.StatusTeapot, w.Code) - assert.Equal(t, "already written", w.Body.String()) + assert.Contains(t, w.Body.String(), "already written") + assert.Contains(t, w.Body.String(), `data: {"type":"error"`) +} + +// case B 回归:Anthropic-backed /responses,Writer 已被写过时 +// ensureForwardErrorResponse 仍要发 response.failed。 +func TestGatewayEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil) + _, _ = c.Writer.WriteString(":\n\n") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + body := w.Body.String() + assert.Contains(t, body, ":\n\n") + assert.Contains(t, body, "event: response.failed\n") + assert.Contains(t, body, `"type":"response.failed"`) } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e1a5b723..6a083f31 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { } // 2. Re-check billing - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 57554cf9..09b20722 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // channelService nil, // resolver nil, // balanceNotifyService + nil, // userPlatformQuotaRepo ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go index 78b07a1a..3206fb4a 100644 --- a/backend/internal/handler/gateway_models_test.go +++ b/backend/internal/handler/gateway_models_test.go @@ -25,7 +25,11 @@ type gatewayModelsResponseForTest struct { } type gatewayModelItemForTest struct { - ID string `json:"id"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + CreatedAt string `json:"created_at"` } func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { @@ -43,7 +47,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand gatewayService: service.NewGatewayService( repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ), } } @@ -127,6 +131,267 @@ func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) } +func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(22) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.5": "gpt-5.5", + "gpt-5.4": "gpt-5.4", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: false, + Models: []string{"gpt-5.5"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(23) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + "gpt-5.5": "gpt-5.5", + "legacy-gpt-2024": "legacy-gpt-2024", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "missing-model", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(26) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-*": "claude-sonnet-4-6", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformAnthropic, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"claude-sonnet-4-6"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(24) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformOpenAI, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Empty(t, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(25) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformOpenAI}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) +} + +func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(27) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformOpenAI}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + ModelsListConfig: service.GroupModelsListConfig{ + Enabled: true, + Models: []string{"gpt-5.5", "gpt-5.4"}, + }, + }, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data)) + require.Equal(t, "model", got.Data[0].Object) + require.NotZero(t, got.Data[0].Created) + require.Equal(t, "openai", got.Data[0].OwnedBy) + require.Empty(t, got.Data[0].CreatedAt) +} + func modelIDsForTest(models []gatewayModelItemForTest) []string { ids := make([]string, 0, len(models)) for _, model := range models { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 665c0677..27ea4404 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 2) billing eligibility check (after wait) - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) status, _, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(body) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/image_concurrency_limiter_test.go b/backend/internal/handler/image_concurrency_limiter_test.go index 20147f16..723c26d0 100644 --- a/backend/internal/handler/image_concurrency_limiter_test.go +++ b/backend/internal/handler/image_concurrency_limiter_test.go @@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t * h := &OpenAIGatewayHandler{ gatewayService: &service.OpenAIGatewayService{}, - billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}), + billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil), apiKeyService: &service.APIKeyService{}, concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 4d523dba..17f0d47e 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go new file mode 100644 index 00000000..bbb67044 --- /dev/null +++ b/backend/internal/handler/openai_embeddings.go @@ -0,0 +1,253 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "strconv" + "strings" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// Embeddings handles the OpenAI-compatible Embeddings API. +// POST /v1/embeddings +func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { + streamStarted := false + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.embeddings", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqLog = reqLog.With(zap.String("model", reqModel)) + setOpsRequestContext(c, reqModel, false) + setOpsEndpointContext(c, "", int16(service.RequestTypeSync)) + + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { + reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err)) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + h.errorResponse(c, status, code, message) + return + } + + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError + switchCount := 0 + maxAccountSwitches := h.maxAccountSwitches + if maxAccountSwitches <= 0 { + maxAccountSwitches = 3 + } + routingStart := time.Now() + + for { + selection, _, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + "", + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportHTTPSSE, + false, + ) + if err != nil { + reqLog.Warn("openai_embeddings.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") + return + } + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, false) + } else { + h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed") + } + return + } + if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + account := selection.Account + if account.Type != service.AccountTypeAPIKey { + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + failedAccountIDs[account.ID] = struct{}{} + continue + } + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog) + if !accountAcquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + writerSizeBeforeForward := c.Writer.Size() + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "") + }() + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, false) + return + } + switchCount++ + reqLog.Warn("openai_embeddings.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if c.Writer.Size() == writerSizeBeforeForward { + h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + } + reqLog.Warn("openai_embeddings.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.embeddings"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_embeddings.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_embeddings.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 9c5560f5..a51eee86 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // 2. Re-check billing eligibility after wait - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { var currentUserRelease func() var currentAccountRelease func() - releaseTurnSlots := func() { + releaseAccountSlot := func() { if currentAccountRelease != nil { currentAccountRelease() currentAccountRelease = nil } + } + releaseTurnSlots := func() { + releaseAccountSlot() if currentUserRelease != nil { currentUserRelease() currentUserRelease = nil @@ -1233,9 +1236,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { return } currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + ensureUserSlotHeld := func() bool { + if currentUserRelease != nil { + return true + } + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return false + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return false + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + return true + } subscription, _ := middleware2.GetSubscriptionFromContext(c) - if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") return @@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { firstMessage, openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), ) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( - ctx, - apiKey.GroupID, - previousResponseID, - sessionHash, - reqModel, - nil, - service.OpenAIUpstreamTransportResponsesWebsocketV2, - false, - ) - if err != nil { - reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") - return - } - if selection == nil || selection.Account == nil { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") - return - } + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError - account := selection.Account - accountMaxConcurrency := account.Concurrency - if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { - accountMaxConcurrency = selection.WaitPlan.MaxConcurrency - } - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") - return - } - fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + for { + reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( ctx, - account.ID, - selection.WaitPlan.MaxConcurrency, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + false, ) if err != nil { - reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + reqLog.Warn("openai.websocket_account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if lastFailoverErr != nil { + closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr) + } else { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + } return } - if !fastAcquired { - closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + if selection == nil || selection.Account == nil { + if lastFailoverErr != nil { + closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr) + } else { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + } return } - accountReleaseFunc = fastReleaseFunc - } - currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) - if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - token, _, err := h.gatewayService.GetAccessToken(ctx, account) - if err != nil { - reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") - return - } - - reqLog.Debug("openai.websocket_account_selected", - zap.Int64("account_id", account.ID), - zap.String("account_name", account.Name), - zap.String("schedule_layer", scheduleDecision.Layer), - zap.Int("candidate_count", scheduleDecision.CandidateCount), - ) - - hooks := &service.OpenAIWSIngressHooks{ - InitialRequestModel: reqModel, - BeforeRequest: func(turn int, payload []byte, originalModel string) error { - if turn == 1 { - return nil - } - if !gjson.ValidBytes(payload) { - return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) - } - model := strings.TrimSpace(originalModel) - if model == "" { - model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - } - if model == "" { - model = reqModel - } - if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { - writeContentModerationWSError(ctx, wsConn, decision) - return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) - } - return nil - }, - BeforeTurn: func(turn int) error { - if turn == 1 { - return nil - } - // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 - releaseTurnSlots() - // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 - userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) - if err != nil { - return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) - } - if !userAcquired { - return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) - } - accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) - if err != nil { - if userReleaseFunc != nil { - userReleaseFunc() - } - return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) - } - if !accountAcquired { - if userReleaseFunc != nil { - userReleaseFunc() - } - return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) - } - currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) - currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) - return nil - }, - AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { - releaseTurnSlots() - if turnErr != nil { - if result == nil || result.ImageCount <= 0 { - return - } - reqLog.Warn("openai.websocket_partial_error_with_image_result", - zap.Int64("account_id", account.ID), - zap.Int("image_count", result.ImageCount), - zap.Error(turnErr), - ) - } - if result == nil { + account := selection.Account + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") return } - if account.Type == service.AccountTypeOAuth { - h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) - inboundEndpoint := GetInboundEndpoint(c) - upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { - if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: inboundEndpoint, - UpstreamEndpoint: upstreamEndpoint, - UserAgent: userAgent, - IPAddress: clientIP, - RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), - APIKeyService: h.apiKeyService, - ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), - }); err != nil { - reqLog.Error("openai.websocket_record_usage_failed", - zap.Int64("account_id", account.ID), - zap.String("request_id", result.RequestID), - zap.Error(err), - ) - } - }) - }, - } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } - // 应用渠道模型映射到 WebSocket 首条消息 - wsFirstMessage := firstMessage - if channelMappingWS.Mapped { - wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) - } - - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - closeStatus, closeReason := summarizeWSCloseErrorForLog(err) - reqLog.Warn("openai.websocket_proxy_failed", - zap.Int64("account_id", account.ID), - zap.Error(err), - zap.String("close_status", closeStatus), - zap.String("close_reason", closeReason), - ) - var closeErr *service.OpenAIWSClientCloseError - if errors.As(err, &closeErr) { - closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + token, _, err := h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") return } - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + ) + + hooks := &service.OpenAIWSIngressHooks{ + InitialRequestModel: reqModel, + BeforeRequest: func(turn int, payload []byte, originalModel string) error { + if turn == 1 { + return nil + } + if !gjson.ValidBytes(payload) { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + model := strings.TrimSpace(originalModel) + if model == "" { + model = strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + } + if model == "" { + model = reqModel + } + if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked { + writeContentModerationWSError(ctx, wsConn, decision) + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil) + } + return nil + }, + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil { + if result == nil || result.ImageCount <= 0 { + return + } + reqLog.Warn("openai.websocket_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(turnErr), + ) + } + if result == nil { + return + } + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + }, + } + + // 应用渠道模型映射到 WebSocket 首条消息 + wsFirstMessage := firstMessage + if channelMappingWS.Mapped { + wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + releaseAccountSlot() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + closeOpenAIWSFailoverExhausted(wsConn, failoverErr) + return + } + switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + closeOpenAIWSFailoverExhausted(wsConn, failoverErr) + return + } + h.gatewayService.RecordOpenAIAccountSwitch() + reqLog.Warn("openai.websocket_upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + if !ensureUserSlotHeld() { + return + } + continue + } + + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) return } - reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) + } func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { @@ -1691,6 +1760,15 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st // handleStreamingAwareError handles errors that may occur after streaming has started func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { + // /v1/responses 的严格 SDK(Codex CLI)要求终止事件必须属于 + // response.completed/failed/incomplete/cancelled 集合。 + // 通用 `event: error` 帧不被识别为终止事件,会导致 + // "stream closed before response.completed"。 + if inboundIsResponses(c) { + if writeResponsesFailedSSE(c, errType, message) { + return + } + } // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { @@ -1710,9 +1788,17 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { - if c == nil || c.Writer == nil || c.Writer.Written() { + if c == nil || c.Writer == nil { return false } + // 旧实现在 Writer.Written 时直接 return false,导致 ping 已 flush 之后的 + // 上游错误(http2 timeout、连接中断等)完全无法把错误传给客户端—— + // HTTP 200 已锁死,TCP 直接 EOF,Codex CLI 报 "stream closed before response.completed"。 + // 这里改成:Writer 已写过时强制走 streamStarted 分支,让 + // handleStreamingAwareError 通过 SSE 发协议合规的 response.failed。 + if c.Writer.Written() { + streamStarted = true + } h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) return true } @@ -1783,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s _ = conn.CloseNow() } +func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) { + if failoverErr == nil { + closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + switch failoverErr.StatusCode { + case http.StatusTooManyRequests: + closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later") + case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable") + case http.StatusUnauthorized, http.StatusForbidden: + closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed") + default: + closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed") + } +} + func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) { if conn == nil || decision == nil { return diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 6bddbce9..7de30e9c 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -174,7 +175,11 @@ func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testin assert.Equal(t, "Upstream request failed", errorObj["message"]) } -func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { +// Writer 已写后 ensureForwardErrorResponse 必须仍然把错误信息以 SSE +// 形式追加给客户端(streamStarted 强制 true)。 +// 这是 case B 修复:旧实现遇到 Writer.Written 直接 return false, +// 客户端只能拿到 silent EOF;Codex CLI 报 "stream closed before response.completed"。 +func TestOpenAIEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -184,9 +189,34 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test h := &OpenAIGatewayHandler{} wrote := h.ensureForwardErrorResponse(c, false) - require.False(t, wrote) + require.True(t, wrote, "must attempt to communicate the failure to the client via SSE") + // 状态码改不了(headers 已 flush),但 body 应该追加 SSE 错误事件。 require.Equal(t, http.StatusTeapot, w.Code) - assert.Equal(t, "already written", w.Body.String()) + assert.Contains(t, w.Body.String(), "already written") + // 非 /responses 路径走 legacy event: error 分支。 + assert.Contains(t, w.Body.String(), "event: error\n") +} + +// case B 回归测试:/responses 路径,Writer 已被写过(模拟 ping flushed), +// ensureForwardErrorResponse 必须发 response.failed,让 Codex 收到合规终止事件。 +func TestOpenAIEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil) + // 模拟 ping 已 flush 的状态:Writer 已写过 1 个字节 + _, _ = c.Writer.WriteString(":\n\n") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + body := w.Body.String() + assert.Contains(t, body, ":\n\n", "earlier ping bytes preserved") + assert.Contains(t, body, "event: response.failed\n", "appended a Responses terminal event") + assert.Contains(t, body, `"type":"response.failed"`) + assert.Contains(t, body, `"code":"upstream_error"`) + assert.Contains(t, body, "Upstream request failed") } func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { @@ -266,7 +296,9 @@ func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { assert.Equal(t, "", w.Body.String()) } -func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { +// Panic 在已 flush 的 /v1/responses 流中:状态码无法改(已 written), +// 但 body 应追加 response.failed 让客户端识别为合规截断而不是 silent EOF。 +func TestOpenAIRecoverResponsesPanic_AppendsResponseFailedAfterWritten(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -284,7 +316,9 @@ func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T }) require.Equal(t, http.StatusTeapot, w.Code) - assert.Equal(t, "already written", w.Body.String()) + body := w.Body.String() + assert.Contains(t, body, "already written") + assert.Contains(t, body, "event: response.failed\n") } func TestOpenAIMissingResponsesDependencies(t *testing.T) { @@ -707,16 +741,31 @@ func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key st } type contentModerationHandlerTestRepo struct { + mu sync.Mutex logs []service.ContentModerationLog } func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error { if log != nil { + r.mu.Lock() + defer r.mu.Unlock() r.logs = append(r.logs, *log) } return nil } +func (r *contentModerationHandlerTestRepo) resetLogs() { + r.mu.Lock() + defer r.mu.Unlock() + r.logs = nil +} + +func (r *contentModerationHandlerTestRepo) logSnapshot() []service.ContentModerationLog { + r.mu.Lock() + defer r.mu.Unlock() + return append([]service.ContentModerationLog(nil), r.logs...) +} + func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -775,7 +824,10 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T }) require.NoError(t, err) require.True(t, decision.Blocked) - repo.logs = nil + require.Eventually(t, func() bool { + return len(repo.logSnapshot()) == 1 + }, time.Second, 10*time.Millisecond) + repo.resetLogs() h := &OpenAIGatewayHandler{ gatewayService: &service.OpenAIGatewayService{}, billingCacheService: &service.BillingCacheService{}, @@ -815,10 +867,11 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) require.Contains(t, closeErr.Reason, "内容审计测试阻断") } - require.Len(t, repo.logs, 1) - require.True(t, repo.logs[0].Flagged) - require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action) - require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt) + logs := repo.logSnapshot() + require.Len(t, logs, 1) + require.True(t, logs[0].Flagged) + require.Equal(t, service.ContentModerationActionBlock, logs[0].Action) + require.Equal(t, "bad prompt", logs[0].InputExcerpt) } func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) { @@ -1042,6 +1095,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in return &account, nil } +type openAIWSFailoverHandlerAccountRepoStub struct { + service.AccountRepository + accounts []service.Account + rateLimitedIDs []int64 +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + out := make([]service.Account, 0, len(s.accounts)) + for _, account := range s.accounts { + if account.Platform == platform && account.IsSchedulable() { + out = append(out, account) + } + } + return out, nil +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return s.ListSchedulableByPlatform(ctx, platform) +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) { + for _, account := range s.accounts { + if account.ID == id { + acc := account + return &acc, nil + } + } + return nil, nil +} + +func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + s.rateLimitedIDs = append(s.rateLimitedIDs, id) + for i := range s.accounts { + if s.accounts[i].ID == id { + reset := resetAt + s.accounts[i].RateLimitResetAt = &reset + break + } + } + return nil +} + type openAIWSUsageHandlerUsageLogRepoStub struct { service.UsageLogRepository created chan *service.UsageLog @@ -1074,6 +1173,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont return out, nil } +func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + firstHitCh := make(chan []byte, 1) + secondHitCh := make(chan []byte, 1) + + firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + return + } + defer func() { _ = conn.CloseNow() }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + _, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr == nil { + firstHitCh <- payload + } + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + _ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`)) + cancelWrite() + })) + defer firstUpstream.Close() + + secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover}) + if err != nil { + return + } + defer func() { _ = conn.CloseNow() }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + _, payload, readErr := conn.Read(readCtx) + cancelRead() + if readErr == nil { + secondHitCh <- payload + } + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + _ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`)) + cancelWrite() + _ = conn.Close(coderws.StatusNormalClosure, "done") + })) + defer secondUpstream.Close() + + groupID := int64(4202) + accounts := []service.Account{ + { + ID: 9902, + Name: "openai-ws-rate-limited", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "sk-first", + "base_url": firstUpstream.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + }, + { + ID: 9903, + Name: "openai-ws-healthy", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 2, + Credentials: map[string]any{ + "api_key": "sk-second", + "base_url": secondUpstream.URL, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + "openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + }, + }, + } + + cfg := &config.Config{} + cfg.RunMode = config.RunModeSimple + cfg.Default.RateMultiplier = 1 + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + cfg.Gateway.MaxAccountSwitches = 3 + + accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts} + rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) + gatewaySvc := service.NewOpenAIGatewayService( + accountRepo, + nil, + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + service.NewBillingService(cfg, nil), + rateLimitSvc, + billingCacheSvc, + nil, + &service.DeferredService{}, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + h := &OpenAIGatewayHandler{ + gatewayService: gatewaySvc, + billingCacheService: billingCacheSvc, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + maxAccountSwitches: 3, + } + + apiKey := &service.APIKey{ + ID: 1802, + GroupID: &groupID, + User: &service.User{ID: 1702, Status: service.StatusActive}, + Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1}) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + handlerServer := httptest.NewServer(router) + defer handlerServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses", + &coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover}, + ) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second) + _, event, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String()) + + select { + case <-firstHitCh: + case <-time.After(3 * time.Second): + t.Fatal("等待第一个上游收到首帧超时") + } + select { + case <-secondHitCh: + case <-time.After(3 * time.Second): + t.Fatal("等待第二个上游收到重放首帧超时") + } + require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs) +} + func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult { t.Helper() gin.SetMode(gin.TestMode) @@ -1168,7 +1462,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU }, nil, nil, nil) } - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) gatewaySvc := service.NewOpenAIGatewayService( accountRepo, usageRepo, @@ -1190,6 +1484,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU channelSvc, nil, nil, + nil, // userPlatformQuotaRepo ) cache := &concurrencyCacheMock{ diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index e6c37272..bbb08014 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index fcf88f6c..002b48e9 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -53,6 +53,8 @@ const ( opsCodeUserNotFound = "USER_NOT_FOUND" opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED" opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated" + opsCodeGroupDeleted = "GROUP_DELETED" + opsCodeGroupDisabled = "GROUP_DISABLED" ) const ( @@ -1012,6 +1014,8 @@ func parseOpsErrorResponse(body []byte) parsedOpsError { var code string if v, ok := errObj["code"]; ok { switch n := v.(type) { + case string: + code = strings.TrimSpace(n) case float64: code = strconvItoa(int(n)) case int: @@ -1190,14 +1194,19 @@ func isOpsClientAuthError(code string, msg string) bool { opsCodeAPIKeyExpired, opsCodeAPIKeyDisabled, opsCodeUserNotFound, - opsCodeUserInactive: + opsCodeUserInactive, + opsCodeGroupDeleted, + opsCodeGroupDisabled: return true } return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required") || strings.Contains(msg, "api key is disabled") || strings.Contains(msg, "user associated with api key not found") || - strings.Contains(msg, "user account is not active") + strings.Contains(msg, "user account is not active") || + strings.Contains(msg, "api key 所属分组已删除") || + strings.Contains(msg, "api key 所属分组已停用") || + strings.Contains(msg, "api key is not assigned to any group") } func isOpsLocalBusinessLimitError(code string, msg string) bool { @@ -1213,6 +1222,7 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool { return strings.Contains(msg, "api key in query parameter is deprecated") || strings.Contains(msg, "query parameter api_key is deprecated") || strings.Contains(msg, "no active subscription found for this group") || + strings.Contains(msg, "subscription is invalid or expired") || strings.Contains(msg, opsErrInsufficientBalance) || strings.Contains(msg, "insufficient account balance") || strings.Contains(msg, "api key group platform is not gemini") || @@ -1223,7 +1233,22 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool { strings.Contains(msg, "daily usage limit exceeded") || strings.Contains(msg, "weekly usage limit exceeded") || strings.Contains(msg, "monthly usage limit exceeded") || - strings.Contains(msg, "requests-per-minute limit exceeded") + strings.Contains(msg, "usage quota exhausted for this platform") || + strings.Contains(msg, "requests-per-minute limit exceeded") || + strings.Contains(msg, "too many pending requests") || + strings.Contains(msg, "concurrency limit exceeded") || + strings.Contains(msg, "image generation concurrency limit exceeded") || + strings.Contains(msg, "this group is restricted to claude code clients") || + strings.Contains(msg, "this group does not allow /v1/messages dispatch") || + strings.Contains(msg, "image generation is not enabled for this group") || + strings.Contains(msg, "token counting is not supported for this platform") || + strings.Contains(msg, "images api is not supported for this platform") || + (strings.Contains(msg, "model ") && strings.Contains(msg, " not in whitelist")) || + (strings.Contains(msg, "beta feature ") && strings.Contains(msg, " is not allowed")) || + (strings.Contains(msg, "openai service_tier=") && strings.Contains(msg, " is not allowed for model")) || + strings.Contains(msg, "this account only allows codex official clients") || + strings.Contains(msg, "openai wsv1 is temporarily unsupported") || + strings.Contains(msg, "openai codex passthrough requires a non-empty instructions field") } func hasOpsUpstreamErrorContext(c *gin.Context) bool { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 81f12b0c..d4e1177e 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -288,6 +288,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) { code: "USER_INACTIVE", status: http.StatusUnauthorized, }, + { + name: "deleted local API key group", + errType: "api_error", + message: "API Key 所属分组已删除", + code: "GROUP_DELETED", + status: http.StatusForbidden, + }, + { + name: "disabled local API key group", + errType: "api_error", + message: "API Key 所属分组已停用", + code: "GROUP_DISABLED", + status: http.StatusForbidden, + }, + { + name: "google deleted API key group message without semantic code", + errType: "api_error", + message: "API Key 所属分组已删除", + code: "403", + status: http.StatusForbidden, + }, + { + name: "anthropic unassigned API key group", + errType: "permission_error", + message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.", + code: "", + status: http.StatusForbidden, + }, { name: "google invalid API key", errType: "api_error", @@ -389,6 +417,15 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) { wantErrType: "api_error", wantPhase: "request", }, + { + name: "gateway subscription invalid cache recheck", + errType: "billing_error", + message: "subscription is invalid or expired", + code: "billing_error", + status: http.StatusForbidden, + wantErrType: "billing_error", + wantPhase: "request", + }, { name: "google insufficient account balance", errType: "api_error", @@ -443,6 +480,132 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) { wantErrType: "api_error", wantPhase: "request", }, + { + name: "user platform daily quota exhausted", + errType: "api_error", + message: "Daily usage quota exhausted for this platform.", + code: "rate_limit_exceeded", + status: http.StatusTooManyRequests, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "local pending queue limit", + errType: "rate_limit_error", + message: "Too many pending requests, please retry later", + code: "", + status: http.StatusTooManyRequests, + wantErrType: "rate_limit_error", + wantPhase: "request", + }, + { + name: "local concurrency limit", + errType: "rate_limit_error", + message: "Concurrency limit exceeded for user, please retry later", + code: "", + status: http.StatusTooManyRequests, + wantErrType: "rate_limit_error", + wantPhase: "request", + }, + { + name: "group claude code only feature gate", + errType: "permission_error", + message: "This group is restricted to Claude Code clients (/v1/messages only)", + code: "", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "group image generation feature gate", + errType: "permission_error", + message: "Image generation is not enabled for this group", + code: "", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "route token counting platform unsupported", + errType: "not_found_error", + message: "Token counting is not supported for this platform", + code: "", + status: http.StatusNotFound, + wantErrType: "not_found_error", + wantPhase: "request", + }, + { + name: "route images API platform unsupported", + errType: "not_found_error", + message: "Images API is not supported for this platform", + code: "", + status: http.StatusNotFound, + wantErrType: "not_found_error", + wantPhase: "request", + }, + { + name: "antigravity model whitelist feature gate", + errType: "permission_error", + message: "model claude-3-5-sonnet not in whitelist", + code: "", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "google antigravity model whitelist feature gate", + errType: "api_error", + message: "model gemini-2.5-pro not in whitelist", + code: "403", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "claude beta policy block", + errType: "invalid_request_error", + message: "beta feature interleaved-thinking-2025-05-14 is not allowed", + code: "", + status: http.StatusBadRequest, + wantErrType: "invalid_request_error", + wantPhase: "request", + }, + { + name: "openai fast policy block", + errType: "permission_error", + message: "openai service_tier=priority is not allowed for model gpt-5.5", + code: "", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "codex official client policy block", + errType: "forbidden_error", + message: "This account only allows Codex official clients", + code: "", + status: http.StatusForbidden, + wantErrType: "forbidden_error", + wantPhase: "request", + }, + { + name: "openai wsv1 unsupported feature gate", + errType: "invalid_request_error", + message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + code: "", + status: http.StatusBadRequest, + wantErrType: "invalid_request_error", + wantPhase: "request", + }, + { + name: "openai passthrough instructions policy block", + errType: "forbidden_error", + message: "OpenAI codex passthrough requires a non-empty instructions field", + code: "", + status: http.StatusForbidden, + wantErrType: "forbidden_error", + wantPhase: "request", + }, } for _, tt := range tests { @@ -479,6 +642,22 @@ func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) { require.Equal(t, "client_request", errorSource) } +func TestClassifyOpsClientBusinessLimitedMarkerExcludesCustomPolicyDenialFromSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied) + + errType := normalizeOpsErrorType("invalid_request_error", "") + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "custom admin policy message", "", http.StatusBadRequest) + + require.Equal(t, "invalid_request_error", errType) + require.Equal(t, "auth", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "client", errorOwner) + require.Equal(t, "client_request", errorSource) +} + func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -583,6 +762,78 @@ func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) { code: "API_KEY_QUOTA_EXHAUSTED", status: http.StatusTooManyRequests, }, + { + name: "provider deleted group shaped error", + message: "API Key 所属分组已删除", + code: "GROUP_DELETED", + status: http.StatusForbidden, + }, + { + name: "provider unassigned group shaped error", + message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.", + code: "403", + status: http.StatusForbidden, + }, + { + name: "provider local quota shaped error", + message: "Daily usage quota exhausted for this platform.", + code: "rate_limit_exceeded", + status: http.StatusTooManyRequests, + }, + { + name: "provider feature gate shaped error", + message: "Image generation is not enabled for this group", + code: "403", + status: http.StatusForbidden, + }, + { + name: "provider token counting unsupported shaped error", + message: "Token counting is not supported for this platform", + code: "404", + status: http.StatusNotFound, + }, + { + name: "provider image API unsupported shaped error", + message: "Images API is not supported for this platform", + code: "404", + status: http.StatusNotFound, + }, + { + name: "provider antigravity whitelist shaped error", + message: "model claude-3-5-sonnet not in whitelist", + code: "403", + status: http.StatusForbidden, + }, + { + name: "provider beta policy shaped error", + message: "beta feature interleaved-thinking-2025-05-14 is not allowed", + code: "400", + status: http.StatusBadRequest, + }, + { + name: "provider openai fast policy shaped error", + message: "openai service_tier=priority is not allowed for model gpt-5.5", + code: "403", + status: http.StatusForbidden, + }, + { + name: "provider codex client policy shaped error", + message: "This account only allows Codex official clients", + code: "403", + status: http.StatusForbidden, + }, + { + name: "provider wsv1 unsupported shaped error", + message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + code: "400", + status: http.StatusBadRequest, + }, + { + name: "provider passthrough instructions shaped error", + message: "OpenAI codex passthrough requires a non-empty instructions field", + code: "403", + status: http.StatusForbidden, + }, } for _, tt := range tests { @@ -628,6 +879,14 @@ func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) { require.Equal(t, "upstream_http", errorSource) } +func TestParseOpsErrorResponsePreservesNestedStringCode(t *testing.T) { + parsed := parseOpsErrorResponse([]byte(`{"error":{"type":"permission_error","code":"GROUP_DELETED","message":"API Key 所属分组已删除"}}`)) + + require.Equal(t, "permission_error", parsed.ErrorType) + require.Equal(t, "GROUP_DELETED", parsed.Code) + require.Equal(t, "API Key 所属分组已删除", parsed.Message) +} + func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/handler/quotaview/helpers.go b/backend/internal/handler/quotaview/helpers.go new file mode 100644 index 00000000..ff4e44a1 --- /dev/null +++ b/backend/internal/handler/quotaview/helpers.go @@ -0,0 +1,104 @@ +// Package quotaview provides shared quota response helpers for user and admin handlers. +// Extracted to avoid import cycles between handler and handler/admin packages. +package quotaview + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB)。 +// includeWindowStart=true 时输出 *_window_start 字段(admin 视角调试用) +func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any { + daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart) + weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart) + monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart) + out := map[string]any{ + "platform": r.Platform, + "daily_usage_usd": daily.usage, + "daily_limit_usd": daily.limit, + "daily_window_resets_at": daily.resetsAt, + "weekly_usage_usd": weekly.usage, + "weekly_limit_usd": weekly.limit, + "weekly_window_resets_at": weekly.resetsAt, + "monthly_usage_usd": monthly.usage, + "monthly_limit_usd": monthly.limit, + "monthly_window_resets_at": monthly.resetsAt, + } + if includeWindowStart { + out["daily_window_start"] = daily.windowStart + out["weekly_window_start"] = weekly.windowStart + out["monthly_window_start"] = monthly.windowStart + } + return out +} + +type windowSlice struct { + usage float64 + limit *float64 + resetsAt *string + windowStart *string +} + +func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice { + out := windowSlice{usage: usage, limit: limit} + if expired { + out.usage = 0 + out.resetsAt = nil + } else if start != nil { + s := nextReset.Format(time.RFC3339) + out.resetsAt = &s + } + if includeStart && start != nil { + s := start.Format(time.RFC3339) + out.windowStart = &s + } + return out +} + +// NeedsDailyReset 判断日窗口是否已过期:start 早于「全局时区当天 0 点」即过期。 +// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。 +func NeedsDailyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return start.Before(timezone.StartOfDay(now)) +} + +func NeedsWeeklyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return start.Before(timezone.StartOfWeek(now)) +} + +// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。 +func NeedsMonthlyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return now.Sub(*start) >= 30*24*time.Hour +} + +func nextDailyResetTime(now time.Time) time.Time { + return timezone.StartOfDay(now).AddDate(0, 0, 1) +} + +func nextWeeklyResetTime(now time.Time) time.Time { + return timezone.StartOfWeek(now).AddDate(0, 0, 7) +} + +// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。 +// 语义: +// - start != nil → 返回 start + 30d(与 billing_cache_service.nextMonthlyResetFrom 一致) +// - start == nil → 退化为 now + 30d(保留旧行为,避免 nil 崩溃) +// +// 导出(首字母大写)以允许测试直接调用。 +func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time { + if start == nil { + return now.Add(30 * 24 * time.Hour) + } + return start.Add(30 * 24 * time.Hour) +} diff --git a/backend/internal/handler/quotaview/helpers_test.go b/backend/internal/handler/quotaview/helpers_test.go new file mode 100644 index 00000000..ca2fcc4e --- /dev/null +++ b/backend/internal/handler/quotaview/helpers_test.go @@ -0,0 +1,133 @@ +package quotaview + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TestNextMonthlyResetTimeFrom_FromStart 验证:start 已知时返回 start+30d,不随 now 漂移。 +func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) { + t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + now := t0.Add(15 * 24 * time.Hour) // t0 + 15d + want := t0.Add(30 * 24 * time.Hour) // t0 + 30d + + got := NextMonthlyResetTimeFrom(&t0, now) + if !got.Equal(want) { + t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got) + } +} + +// TestNextMonthlyResetTimeFrom_NilStart 验证:start=nil 时退化为 now+30d(不 panic)。 +func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) { + now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC) + want := now.Add(30 * 24 * time.Hour) + + got := NextMonthlyResetTimeFrom(nil, now) + if !got.Equal(want) { + t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got) + } +} + +// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证: +// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record, +// monthly_window_resets_at 始终等于 windowStart+30d,不随 now 漂移。 +func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) { + windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339) + + r := service.UserPlatformQuotaRecord{ + Platform: "openai", + MonthlyUsageUSD: 5.0, + MonthlyWindowStart: &windowStart, + } + + // 第一次调用:now = windowStart + 5d + now1 := windowStart.Add(5 * 24 * time.Hour) + out1 := LazyZeroQuotaForResponse(r, now1, false) + resetsAt1, ok1 := out1["monthly_window_resets_at"] + if !ok1 || resetsAt1 == nil { + t.Fatal("first call: monthly_window_resets_at should be set for active window") + } + s1, ok := resetsAt1.(*string) + if !ok || s1 == nil { + t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1) + } + if *s1 != wantResetsAt { + t.Errorf("first call: want %s, got %s", wantResetsAt, *s1) + } + + // 第二次调用:now = windowStart + 10d(不同 now,但 resetsAt 应不变) + now2 := windowStart.Add(10 * 24 * time.Hour) + out2 := LazyZeroQuotaForResponse(r, now2, false) + resetsAt2, ok2 := out2["monthly_window_resets_at"] + if !ok2 || resetsAt2 == nil { + t.Fatal("second call: monthly_window_resets_at should be set for active window") + } + s2, ok := resetsAt2.(*string) + if !ok || s2 == nil { + t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2) + } + if *s2 != wantResetsAt { + t.Errorf("second call: want %s, got %s", wantResetsAt, *s2) + } + + // 两次结果必须相等 + if *s1 != *s2 { + t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2) + } +} + +// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。 +func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + // now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08(北京 5/26) + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) + + // start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08(北京 5/25)→ 应判定为过期 + startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) + if !NeedsDailyReset(&startPrevBeijingDay, now) { + t.Error("上一个北京日的窗口应判定为过期") + } + + // start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08(北京 5/26 同日)→ 不应过期 + startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC) + if NeedsDailyReset(&startSameBeijingDay, now) { + t.Error("同一北京日的窗口不应判定为过期") + } +} + +// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。 +func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 + want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00 + if got := nextDailyResetTime(now); !got.Equal(want) { + t.Errorf("nextDailyResetTime = %v, want %v", got, want) + } +} + +// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。 +func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + // 北京 2026-05-26(周二)→ 下周一是 2026-06-01 + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二 + want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location()) + if got := nextWeeklyResetTime(now); !got.Equal(want) { + t.Errorf("nextWeeklyResetTime = %v, want %v", got, want) + } +} diff --git a/backend/internal/handler/stream_error_event.go b/backend/internal/handler/stream_error_event.go new file mode 100644 index 00000000..f3a33a8c --- /dev/null +++ b/backend/internal/handler/stream_error_event.go @@ -0,0 +1,165 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// responsesFailedError 对齐 OpenAI Responses 协议 error 子对象。 +type responsesFailedError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// responsesFailedBody 对齐 apicompat.makeResponsesCompletedEvent 输出的 response 子对象字段集。 +// Output 用空 slice(不是 nil)确保 marshal 为 `[]` 而非 `null`。 +type responsesFailedBody struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model,omitempty"` + Status string `json:"status"` + Output []any `json:"output"` + Error responsesFailedError `json:"error"` +} + +// responsesFailedEvent 是写入 SSE data 行的顶层结构。 +// 故意不带 sequence_number:spec 标记可选,且本函数被调用时无法可靠拿到 last seq。 +type responsesFailedEvent struct { + Type string `json:"type"` + Response responsesFailedBody `json:"response"` +} + +// writeResponsesFailedSSE emits a `response.failed` SSE event in the OpenAI +// Responses API protocol after the stream has already started. +// +// 必要性:一旦 SSE 头和任意数据(例如等待槽位时的 ping comment)已经 flush, +// HTTP 200 状态码就被固化。此后若网关需要回报错误,只能继续通过 SSE 事件传达。 +// 通用的 `event: error` 帧不是 Responses 协议规定的终止事件, +// Codex CLI 等严格 SDK 会因为没收到 `response.completed/failed/incomplete/cancelled` +// 而抛出 "stream closed before response.completed"。 +// +// 字段集对齐 apicompat.makeResponsesCompletedEvent:id/object/model/status/output/error。 +// 故意不写 sequence_number:本函数被调用时无法可靠拿到当前流的 last sequence, +// 而 OpenAI spec 将 sequence_number 设为可选;省略避免破坏单调性约束。 +// +// 返回 true 表示已尝试 SSE 写出(不论 Write 是否成功,caller 都应直接 return)。 +// 返回 false 表示 writer 不支持 Flusher,无法以 SSE 形式回报错误; +// 此时 caller 也无法回退到 JSON(HTTP 200 已固化),通常意味着连接已经损坏, +// 应当让请求处理函数 return,由上层关闭连接。 +func writeResponsesFailedSSE(c *gin.Context, errType, message string) bool { + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return false + } + + payload, err := json.Marshal(responsesFailedEvent{ + Type: "response.failed", + Response: responsesFailedBody{ + ID: synthesizeResponseID(c), + Object: "response", + Model: requestModel(c), + Status: "failed", + Output: []any{}, + Error: responsesFailedError{ + Code: mapResponsesErrorCode(errType), + Message: message, + }, + }, + }) + if err != nil { + _ = c.Error(err) + return true + } + + if _, err := fmt.Fprintf(c.Writer, "event: response.failed\ndata: %s\n\n", payload); err != nil { + _ = c.Error(err) + return true + } + flusher.Flush() + return true +} + +// inboundIsResponses 判断当前请求是否落在任何 /responses 路由上。 +// +// 不能直接用 GetInboundEndpoint(c) == EndpointResponses 比较,因为 +// NormalizeInboundEndpoint 只识别包含 "/v1/responses" 子串的路径; +// 项目里实际注册了多组路由(gateway_v1、top-level bare、codex direct), +// 其中 r.POST("/responses", ...) 和 codexDirect.POST("/responses", ...) +// 的 c.FullPath() 不含 "/v1/" 前缀,会被归一化为原始路径, +// 导致协议合规终止事件没法发出去。 +// +// 这里用 FullPath 的后缀判断,覆盖所有变体: +// - /v1/responses +// - /v1/responses/compact +// - /responses +// - /responses/compact +// - /backend-api/codex/responses +// - /backend-api/codex/responses/compact +func inboundIsResponses(c *gin.Context) bool { + if c == nil { + return false + } + p := strings.TrimRight(c.FullPath(), "/") + if p == "" && c.Request != nil && c.Request.URL != nil { + p = strings.TrimRight(c.Request.URL.Path, "/") + } + if p == "" { + return false + } + return strings.HasSuffix(p, "/responses") || strings.Contains(p, "/responses/") +} + +// synthesizeResponseID 为合成的 response.failed 事件生成一个稳定的 id。 +// 优先复用 server 端生成的 request_id(存在 request.Context 里,由 request_logger 写入), +// 以便客户端报错能与 server 日志关联;缺失时回退 uuid。 +func synthesizeResponseID(c *gin.Context) string { + if c != nil && c.Request != nil { + if rid, ok := c.Request.Context().Value(ctxkey.RequestID).(string); ok { + if rid = strings.TrimSpace(rid); rid != "" { + return "resp_" + strings.ReplaceAll(rid, "-", "") + } + } + } + return "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") +} + +// requestModel 取当前请求的 inbound model(由 setOpsRequestContext 写入)。 +// 缺失时返回 "";caller 据此决定是否忽略该字段。 +func requestModel(c *gin.Context) string { + if c == nil { + return "" + } + if v, ok := c.Get(opsModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" +} + +// mapResponsesErrorCode 把内部 errType 映射为 Responses 协议常见的 error.code。 +// 无明确映射时原样返回,保证至少可读。 +func mapResponsesErrorCode(errType string) string { + switch errType { + case "rate_limit_error": + return "rate_limit_exceeded" + case "invalid_request_error": + return "invalid_request" + case "permission_error": + return "permission_denied" + case "authentication_error": + return "authentication_failed" + case "upstream_error": + return "upstream_error" + case "server_error", "api_error", "": + return "server_error" + default: + return errType + } +} diff --git a/backend/internal/handler/stream_error_event_test.go b/backend/internal/handler/stream_error_event_test.go new file mode 100644 index 00000000..f24cf97f --- /dev/null +++ b/backend/internal/handler/stream_error_event_test.go @@ -0,0 +1,253 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Regression for the production incident on 2026-05-24 around 9:13 CST: +// user 16 sent /v1/responses with stream:true via Codex CLI; the user-concurrency +// slot wait sent SSE ping comments (flushing HTTP 200 + headers), then the 30s +// timeout fired and the handler emitted `event: error\ndata: {...}`. Codex CLI +// does not recognize that as a Responses terminal event and reports +// "stream closed before response.completed". The fix is to emit a synthetic +// response.failed event when the inbound endpoint is /v1/responses. + +func newGinContextForEndpoint(t *testing.T, endpoint string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, endpoint, nil) + return c, w +} + +// parseResponsesFailedSSE 抽出 SSE 中 data 行的 JSON,返回 (response 对象, error 对象)。 +func parseResponsesFailedSSE(t *testing.T, body string) (map[string]any, map[string]any) { + t.Helper() + require.True(t, strings.HasPrefix(body, "event: response.failed\n"), + "expect event: response.failed prefix, got: %q", body) + require.True(t, strings.HasSuffix(body, "\n\n")) + + lines := strings.SplitN(strings.TrimSuffix(body, "\n\n"), "\n", 2) + require.Len(t, lines, 2) + require.True(t, strings.HasPrefix(lines[1], "data: ")) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data must be valid JSON: %s", jsonStr) + + assert.Equal(t, "response.failed", parsed["type"]) + // 故意不发 sequence_number,避免与后续真实事件的序号冲突。 + _, hasSeq := parsed["sequence_number"] + assert.False(t, hasSeq, "synthetic event must not emit sequence_number") + + resp, ok := parsed["response"].(map[string]any) + require.True(t, ok, "response object missing") + assert.Equal(t, "response", resp["object"]) + assert.Equal(t, "failed", resp["status"]) + + errObj, ok := resp["error"].(map[string]any) + require.True(t, ok, "error object missing") + + return resp, errObj +} + +// OpenAI handler: /v1/responses streaming, after stream started, must emit response.failed. +func TestOpenAIHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + "Concurrency limit exceeded for user, please retry later", true) + + resp, errObj := parseResponsesFailedSSE(t, w.Body.String()) + + id, _ := resp["id"].(string) + assert.True(t, strings.HasPrefix(id, "resp_"), "id should start with resp_, got %q", id) + assert.Equal(t, "rate_limit_exceeded", errObj["code"]) + assert.Equal(t, "Concurrency limit exceeded for user, please retry later", errObj["message"]) +} + +// 当 setOpsRequestContext 写过 model,合成事件应回填该字段(与 codebase 已有 makeResponsesCompletedEvent 对齐)。 +func TestOpenAIHandleStreamingAwareError_ResponsesStreamingIncludesModel(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + setOpsRequestContext(c, "gpt-5.5", true) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true) + + resp, _ := parseResponsesFailedSSE(t, w.Body.String()) + assert.Equal(t, "gpt-5.5", resp["model"]) +} + +// 没有 model 时 model 字段不应出现(避免发空字符串污染下游解析)。 +func TestOpenAIHandleStreamingAwareError_ResponsesStreamingOmitsEmptyModel(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true) + + resp, _ := parseResponsesFailedSSE(t, w.Body.String()) + _, hasModel := resp["model"] + assert.False(t, hasModel, "model field must be omitted when unknown") +} + +// 当 request.Context 携带 ctxkey.RequestID 时,合成 id 应与之关联,便于和 server log 串起来。 +func TestOpenAIHandleStreamingAwareError_ResponsesStreamingReusesRequestID(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + c.Request = c.Request.WithContext( + context.WithValue(c.Request.Context(), ctxkey.RequestID, "fd277bc5-ff7e-45d1-8aa9-f54e1df318f1"), + ) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "x", true) + + resp, _ := parseResponsesFailedSSE(t, w.Body.String()) + assert.Equal(t, "resp_fd277bc5ff7e45d18aa9f54e1df318f1", resp["id"]) +} + +// 与旧分支的 TestOpenAIHandleStreamingAwareError_JSONEscaping 对齐: +// 新的 response.failed payload 也必须正确转义 message 里的特殊字符, +// 否则下游 SDK 解析 JSON 时会失败。 +func TestOpenAIHandleStreamingAwareError_ResponsesStreamingJSONEscaping(t *testing.T) { + cases := []struct { + name string + errType string + message string + }{ + {"双引号", "server_error", `upstream returned "invalid" response`}, + {"反斜杠", "server_error", `path C:\Users\test\file.txt not found`}, + {"双引号+反斜杠", "upstream_error", `error parsing "key\value": unexpected token`}, + {"换行与制表", "server_error", "line1\nline2\ttab"}, + {"普通", "upstream_error", "Upstream service temporarily unavailable"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tc.errType, tc.message, true) + + _, errObj := parseResponsesFailedSSE(t, w.Body.String()) + assert.Equal(t, tc.message, errObj["message"], "message 必须被原样还原") + }) + } +} + +// OpenAI handler: /v1/chat/completions streaming keeps the legacy event: error format +// (out of scope for this fix; covered to prevent regression of unrelated paths). +func TestOpenAIHandleStreamingAwareError_ChatCompletionsStreamingKeepsLegacy(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointChatCompletions) + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true) + + body := w.Body.String() + assert.True(t, strings.HasPrefix(body, "event: error\n"), "got: %q", body) +} + +// Gateway (Anthropic-backed) handler: /v1/responses path also must emit response.failed. +func TestGatewayHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointResponses) + h := &GatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "upstream gone", true) + + _, errObj := parseResponsesFailedSSE(t, w.Body.String()) + assert.Equal(t, "upstream_error", errObj["code"]) + assert.Equal(t, "upstream gone", errObj["message"]) +} + +// Gateway handler: /v1/messages preserves the legacy data:{type:error,...} format +// (Anthropic spec accepts a type:"error" stream event). +func TestGatewayHandleStreamingAwareError_MessagesStreamingKeepsLegacy(t *testing.T) { + c, w := newGinContextForEndpoint(t, EndpointMessages) + h := &GatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true) + + body := w.Body.String() + assert.True(t, strings.HasPrefix(body, `data: {"type":"error"`), "got: %q", body) +} + +// 项目里 /responses 注册在多组路由:/v1/responses(gateway)、裸 /responses(top-level)、 +// /backend-api/codex/responses(codex direct)。我们 fix 必须覆盖全部, +// 否则一些客户端走的路径就不会发 response.failed,照样报 stream closed。 +// 这是生产 2026-05-24 ~11:05 UTC user 16 实际命中的 bug。 +func TestInboundIsResponses_CoversAllRoutes(t *testing.T) { + cases := []struct { + route string + want bool + }{ + {"/v1/responses", true}, + {"/v1/responses/compact", true}, + {"/responses", true}, // <-- 用户 16 实际走这条 + {"/responses/compact", true}, + {"/backend-api/codex/responses", true}, + {"/backend-api/codex/responses/compact", true}, + {"/v1/chat/completions", false}, + {"/v1/messages", false}, + {"/", false}, + {"/responses-fake", false}, + } + for _, tc := range cases { + t.Run(tc.route, func(t *testing.T) { + c, _ := newGinContextForEndpoint(t, tc.route) + assert.Equal(t, tc.want, inboundIsResponses(c), "route=%q", tc.route) + }) + } +} + +// 用 c.Request.URL.Path 作为 fallback(当 c.FullPath() 为空时,例如某些测试 fixture)。 +func TestInboundIsResponses_FallsBackToURLPath(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/responses", nil) + // 这种情况下 c.FullPath() 是 "",必须 fallback 到 URL.Path + assert.True(t, inboundIsResponses(c), "URL.Path fallback must work when FullPath is empty") +} + +// 回归生产事故:用户 16 走 /responses 路径,必须发 response.failed。 +func TestOpenAIHandleStreamingAwareError_BareResponsesRouteEmitsResponseFailed(t *testing.T) { + c, w := newGinContextForEndpoint(t, "/responses") + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + "Concurrency limit exceeded for user, please retry later", true) + + resp, errObj := parseResponsesFailedSSE(t, w.Body.String()) + id, _ := resp["id"].(string) + assert.True(t, strings.HasPrefix(id, "resp_")) + assert.Equal(t, "rate_limit_exceeded", errObj["code"]) +} + +// Synthesized response.failed id falls back to uuid when no request_id is present. +func TestSynthesizeResponseID_FallbackUUID(t *testing.T) { + c, _ := newGinContextForEndpoint(t, EndpointResponses) + id := synthesizeResponseID(c) + assert.True(t, strings.HasPrefix(id, "resp_")) + // uuid 去掉短横线后 32 hex 字符;前缀 "resp_" 共 37。 + assert.Len(t, id, 37) +} + +func TestMapResponsesErrorCode(t *testing.T) { + cases := []struct{ in, out string }{ + {"rate_limit_error", "rate_limit_exceeded"}, + {"invalid_request_error", "invalid_request"}, + {"permission_error", "permission_denied"}, + {"authentication_error", "authentication_failed"}, + {"upstream_error", "upstream_error"}, + {"server_error", "server_error"}, + {"api_error", "server_error"}, + {"", "server_error"}, + {"custom_thing", "custom_thing"}, + } + for _, tc := range cases { + assert.Equal(t, tc.out, mapResponsesErrorCode(tc.in), "in=%q", tc.in) + } +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 95cb1482..d4b10b2b 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -3,8 +3,10 @@ package handler import ( "context" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -14,11 +16,12 @@ import ( // UserHandler handles user-related requests type UserHandler struct { - userService *service.UserService - authService *service.AuthService - emailService *service.EmailService - emailCache service.EmailCache - affiliateService *service.AffiliateService + userService *service.UserService + authService *service.AuthService + emailService *service.EmailService + emailCache service.EmailCache + affiliateService *service.AffiliateService + userPlatformQuotaRepo service.UserPlatformQuotaRepository } // NewUserHandler creates a new UserHandler @@ -28,16 +31,44 @@ func NewUserHandler( emailService *service.EmailService, emailCache service.EmailCache, affiliateService *service.AffiliateService, + userPlatformQuotaRepo service.UserPlatformQuotaRepository, ) *UserHandler { return &UserHandler{ - userService: userService, - authService: authService, - emailService: emailService, - emailCache: emailCache, - affiliateService: affiliateService, + userService: userService, + authService: authService, + emailService: emailService, + emailCache: emailCache, + affiliateService: affiliateService, + userPlatformQuotaRepo: userPlatformQuotaRepo, } } +// GetMyPlatformQuotas GET /user/platform-quotas +// 返回当前 JWT 用户的 platform quota 状态。 +// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null(不写 DB) +func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.userPlatformQuotaRepo == nil { + response.Success(c, map[string]any{"platform_quotas": []any{}}) + return + } + records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + out := make([]map[string]any, 0, len(records)) + for _, r := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + // ChangePasswordRequest represents the change password request payload type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index fb690858..41647802 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } -func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil @@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) recorder := httptest.NewRecorder() @@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -285,7 +289,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) recorder := httptest.NewRecorder() @@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t * ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) recorder := httptest.NewRecorder() @@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/user_platform_quotas_handler_test.go b/backend/internal/handler/user_platform_quotas_handler_test.go new file mode 100644 index 00000000..96d13fd3 --- /dev/null +++ b/backend/internal/handler/user_platform_quotas_handler_test.go @@ -0,0 +1,212 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集 +type fakeQuotaRepoForUserHandler struct { + service.UserPlatformQuotaRepository + records []service.UserPlatformQuotaRecord +} + +func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return f.records, nil +} + +func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) { + repo := &fakeQuotaRepoForUserHandler{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + h.GetMyPlatformQuotas(c) + if w.Code != 200 { + t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String()) + } + var body struct { + Code int `json:"code"` + Data struct { + PlatformQuotas []any `json:"platform_quotas"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String()) + } + if body.Code != 0 { + t.Errorf("expected code=0, got %d", body.Code) + } + if body.Data.PlatformQuotas == nil { + // nil 和 empty slice 均视为可接受(JSON 可能序列化为 null 或 []) + // 此断言只验证 HTTP 200 + code=0 即可 + } +} + +func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) { + pastStart := time.Now().UTC().AddDate(0, 0, -2) + daily := 5.0 + repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{ + UserID: 42, + Platform: "anthropic", + DailyLimitUSD: &daily, + DailyUsageUSD: 3.0, + DailyWindowStart: &pastStart, + }}} + h := &UserHandler{userPlatformQuotaRepo: repo} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + h.GetMyPlatformQuotas(c) + + if w.Code != 200 { + t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String()) + } + + // 解析 response,验证过期 daily 的 usage_usd=0 且 window_resets_at=null + body := w.Body.String() + if !strings.Contains(body, `"daily_usage_usd":0`) { + t.Errorf("expected daily_usage_usd:0 in body, got: %s", body) + } + if !strings.Contains(body, `"daily_window_resets_at":null`) { + t.Errorf("expected daily_window_resets_at:null in body, got: %s", body) + } +} + +func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99}) + h.GetMyPlatformQuotas(c) + if w.Code != 200 { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + // 不设置 auth subject + h.GetMyPlatformQuotas(c) + if w.Code != 401 { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) { + start := time.Now().UTC().Add(-1 * time.Hour) + r := service.UserPlatformQuotaRecord{ + Platform: "anthropic", + DailyUsageUSD: 1.0, + DailyWindowStart: &start, + } + out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false) + if _, ok := out["daily_window_start"]; ok { + t.Error("user view should not include daily_window_start") + } +} + +func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) { + start := time.Now().UTC().Add(-1 * time.Hour) + r := service.UserPlatformQuotaRecord{ + Platform: "anthropic", + DailyWindowStart: &start, + } + out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true) + if _, ok := out["daily_window_start"]; !ok { + t.Error("admin view should include daily_window_start") + } +} + +func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) { + // 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径 + now := time.Now() + today := timezone.StartOfDay(now) + usage := 2.5 + r := service.UserPlatformQuotaRecord{ + Platform: "openai", + DailyUsageUSD: usage, + DailyWindowStart: &today, + } + out := quotaview.LazyZeroQuotaForResponse(r, now, false) + if out["daily_usage_usd"] != usage { + t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"]) + } + // 活跃窗口应有 resets_at(非 nil) + if out["daily_window_resets_at"] == nil { + t.Error("active window should have daily_window_resets_at set") + } +} + +func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsDailyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need reset") + } +} + +func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) { + old := time.Now().UTC().AddDate(0, 0, -1) + if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) { + t.Error("yesterday start should need daily reset") + } +} + +func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need weekly reset") + } +} + +func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need monthly reset") + } +} + +// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义(C-NEW-1)。 +func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) { + start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期 + if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) { + t.Error("31 days ago should need monthly reset (30-day rolling)") + } +} + +func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) { + start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效 + if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) { + t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)") + } +} + +// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。 +func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) { + // 窗口起始 4 月 20 日;5 月 1 日仅过了 11 天,不足 30 天,不应重置 + windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + if quotaview.NeedsMonthlyReset(&windowStart, now) { + t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)") + } +} diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index 7490654d..bb566081 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -517,6 +517,33 @@ func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) { assert.Nil(t, FinalizeResponsesAnthropicStream(state)) } +func TestResponsesEventToAnthropicEvents_TopLevelTerminalUsage(t *testing.T) { + state := NewResponsesEventToAnthropicState() + state.Model = "gpt-4o" + + events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + Usage: &ResponsesUsage{ + InputTokens: 20, + OutputTokens: 6, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 5, + }, + }, + }, state) + + require.Len(t, events, 2) + assert.Equal(t, "message_delta", events[0].Type) + require.NotNil(t, events[0].Usage) + assert.Equal(t, 15, events[0].Usage.InputTokens) + assert.Equal(t, 5, events[0].Usage.CacheReadInputTokens) + assert.Equal(t, 6, events[0].Usage.OutputTokens) + assert.Equal(t, "message_stop", events[1].Type) +} + func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) { state := NewResponsesEventToAnthropicState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index ad26f273..016c2415 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -846,6 +846,33 @@ func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) { assert.Nil(t, FinalizeResponsesChatStream(state)) } +func TestResponsesEventToChatChunks_TopLevelTerminalUsage(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + Usage: &ResponsesUsage{ + InputTokens: 21, + OutputTokens: 9, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 4, + }, + }, + }, state) + + require.Len(t, chunks, 2) + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 21, chunks[1].Usage.PromptTokens) + assert.Equal(t, 9, chunks[1].Usage.CompletionTokens) + require.NotNil(t, chunks[1].Usage.PromptTokensDetails) + assert.Equal(t, 4, chunks[1].Usage.PromptTokensDetails.CachedTokens) +} + func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) { state := NewResponsesEventToChatState() state.Model = "gpt-4o" diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go index d7ef0145..6913f2eb 100644 --- a/backend/internal/pkg/apicompat/responses_to_anthropic.go +++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go @@ -567,6 +567,12 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo events = append(events, closeCurrentBlock(state)...) stopReason := "end_turn" + if evt.Usage != nil { + usage := anthropicUsageFromResponsesUsage(evt.Usage) + state.InputTokens = usage.InputTokens + state.OutputTokens = usage.OutputTokens + state.CacheReadInputTokens = usage.CacheReadInputTokens + } if evt.Response != nil { if evt.Response.Usage != nil { usage := anthropicUsageFromResponsesUsage(evt.Response.Usage) diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go index 2386771d..7e8354ee 100644 --- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -293,20 +293,12 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo state.Finalized = true finishReason := "stop" + if evt.Usage != nil { + state.Usage = chatUsageFromResponsesUsage(evt.Usage) + } if evt.Response != nil { if evt.Response.Usage != nil { - u := evt.Response.Usage - usage := &ChatUsage{ - PromptTokens: u.InputTokens, - CompletionTokens: u.OutputTokens, - TotalTokens: u.InputTokens + u.OutputTokens, - } - if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { - usage.PromptTokensDetails = &ChatTokenDetails{ - CachedTokens: u.InputTokensDetails.CachedTokens, - } - } - state.Usage = usage + state.Usage = chatUsageFromResponsesUsage(evt.Response.Usage) } switch evt.Response.Status { @@ -340,6 +332,23 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo return chunks } +func chatUsageFromResponsesUsage(u *ResponsesUsage) *ChatUsage { + if u == nil { + return nil + } + usage := &ChatUsage{ + PromptTokens: u.InputTokens, + CompletionTokens: u.OutputTokens, + TotalTokens: u.InputTokens + u.OutputTokens, + } + if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: u.InputTokensDetails.CachedTokens, + } + } + return usage +} + func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk { return ChatCompletionsChunk{ ID: state.ID, diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index 7c46ccaf..8b576647 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -380,6 +380,8 @@ type ResponsesStreamEvent struct { // response.created / response.completed / response.done / response.failed / response.incomplete Response *ResponsesResponse `json:"response,omitempty"` + // 部分 OpenAI 兼容上游会把 usage 放在终止事件顶层,而不是 response.usage。 + Usage *ResponsesUsage `json:"usage,omitempty"` // response.output_item.added / response.output_item.done Item *ResponsesOutput `json:"item,omitempty"` diff --git a/backend/internal/pkg/timezone/timezone_test.go b/backend/internal/pkg/timezone/timezone_test.go index ac9cdde6..610b1bb9 100644 --- a/backend/internal/pkg/timezone/timezone_test.go +++ b/backend/internal/pkg/timezone/timezone_test.go @@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) { _ = Now() _ = StartOfDay(Now()) } + +func TestStartOfWeek_Boundaries(t *testing.T) { + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = Init("UTC") }) + + loc := Location() + wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一 + + cases := []struct { + name string + in time.Time + }{ + {"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)}, + {"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)}, + {"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := StartOfWeek(c.in); !got.Equal(wantMon) { + t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon) + } + }) + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index e62c4e52..30f1fc2e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1085,7 +1085,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return nil } -func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { +func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { if scope == "" { return nil } @@ -1094,6 +1094,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco "rate_limited_at": now.Format(time.RFC3339), "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), } + if len(reason) > 0 { + if value := strings.TrimSpace(reason[0]); value != "" { + payload["reason"] = value + } + } raw, err := json.Marshal(payload) if err != nil { return err @@ -1129,6 +1134,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 43b13937..bfe09283 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -183,6 +183,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, group.FieldMessagesDispatchModelConfig, + group.FieldModelsListConfig, group.FieldRpmLimit, ) }). @@ -723,6 +724,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + ModelsListConfig: g.ModelsListConfig, RPMLimit: g.RpmLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 6922b4c8..60dae954 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6 key := billingRateLimitKey(keyID) return c.rdb.Del(ctx, key).Err() } + +// ============================================ +// user × platform quota 缓存 +// ============================================ + +// userPlatformQuotaCacheKey 构造 Redis key +func userPlatformQuotaCacheKey(userID int64, platform string) string { + return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform) +} + +func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) { + key := userPlatformQuotaCacheKey(userID, platform) + fields := []string{ + "daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version", + "daily_limit", "weekly_limit", "monthly_limit", + "daily_window_start", "weekly_window_start", "monthly_window_start", + } + vals, err := c.rdb.HMGet(ctx, key, fields...).Result() + if err != nil { + return nil, false, err + } + // 前4个全为nil → key 不存在 + if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil { + return nil, false, nil + } + parseFloat := func(v any) float64 { + if v == nil { + return 0 + } + s, ok := v.(string) + if !ok { + return 0 + } + f, _ := strconv.ParseFloat(s, 64) + return f + } + parseFloatPtr := func(v any) *float64 { + if v == nil { + return nil + } + s, ok := v.(string) + if !ok || s == "" { + return nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil + } + return &f + } + parseTimePtr := func(v any) *time.Time { + if v == nil { + return nil + } + s, ok := v.(string) + if !ok || s == "" { + return nil + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil + } + t := time.Unix(n, 0).UTC() + return &t + } + parseInt64 := func(v any) int64 { + if v == nil { + return 0 + } + s, ok := v.(string) + if !ok { + return 0 + } + n, _ := strconv.ParseInt(s, 10, 64) + return n + } + return &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: parseFloat(vals[0]), + WeeklyUsageUSD: parseFloat(vals[1]), + MonthlyUsageUSD: parseFloat(vals[2]), + Version: parseInt64(vals[3]), + SchemaVersion: parseInt64(vals[4]), + DailyLimitUSD: parseFloatPtr(vals[5]), + WeeklyLimitUSD: parseFloatPtr(vals[6]), + MonthlyLimitUSD: parseFloatPtr(vals[7]), + DailyWindowStart: parseTimePtr(vals[8]), + WeeklyWindowStart: parseTimePtr(vals[9]), + MonthlyWindowStart: parseTimePtr(vals[10]), + }, true, nil +} + +func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + key := userPlatformQuotaCacheKey(userID, platform) + pipe := c.rdb.TxPipeline() + + // 浮点可空字段:nil → 空字符串(读取时 parseFloatPtr 返回 nil,表示无限额) + fmtFloatPtr := func(p *float64) string { + if p == nil { + return "" + } + return strconv.FormatFloat(*p, 'f', -1, 64) + } + // time.Time 可空字段:nil → 空字符串;有值 → unix 秒 + fmtTimePtr := func(p *time.Time) string { + if p == nil { + return "" + } + return strconv.FormatInt(p.Unix(), 10) + } + + pipe.HSet(ctx, key, + "daily_usage", entry.DailyUsageUSD, + "weekly_usage", entry.WeeklyUsageUSD, + "monthly_usage", entry.MonthlyUsageUSD, + "version", entry.Version, + "schema_version", entry.SchemaVersion, + "daily_limit", fmtFloatPtr(entry.DailyLimitUSD), + "weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD), + "monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD), + "daily_window_start", fmtTimePtr(entry.DailyWindowStart), + "weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart), + "monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart), + ) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err() +} + +// updateUserPlatformQuotaUsageScript 缓存累加:EXISTS + schema_version 双重守卫。 +// 旧版 entry(schema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后 +// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。 +// key 不存在同样跳过(由下次 SetCache 重建)。 +// KEYS[1] = hash key +// ARGV[1] = cost (string float) +// ARGV[2] = ttl seconds +// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1) +const updateUserPlatformQuotaUsageScript = ` +if redis.call("EXISTS", KEYS[1]) == 0 then + return 0 +end +local ver = redis.call("HGET", KEYS[1], "schema_version") +if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then + return 0 +end +redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1]) +redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1]) +redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1]) +redis.call("HINCRBY", KEYS[1], "version", 1) +redis.call("EXPIRE", KEYS[1], ARGV[2]) +return 1 +` + +func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + key := userPlatformQuotaCacheKey(userID, platform) + _, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key}, + strconv.FormatFloat(cost, 'f', -1, 64), + int(ttl.Seconds()), + service.UserPlatformQuotaCacheSchemaV1, + ).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} diff --git a/backend/internal/repository/billing_cache_user_platform_quota_test.go b/backend/internal/repository/billing_cache_user_platform_quota_test.go new file mode 100644 index 00000000..8d49fd31 --- /dev/null +++ b/backend/internal/repository/billing_cache_user_platform_quota_test.go @@ -0,0 +1,134 @@ +//go:build unit + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return &billingCache{rdb: rdb}, mr +} + +func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) { + c, _ := newMiniRedisCache(t) + entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic") + if err != nil { + t.Fatal(err) + } + if ok || entry != nil { + t.Errorf("expected miss, got ok=%v entry=%v", ok, entry) + } +} + +func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + dailyLimit := 20.0 + ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC) + in := &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 1.5, + WeeklyUsageUSD: 3.0, + MonthlyUsageUSD: 10.0, + Version: 7, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: &dailyLimit, + DailyWindowStart: &ts, + } + if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil { + t.Fatal(err) + } + got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if err != nil || !ok { + t.Fatalf("get: ok=%v err=%v", ok, err) + } + if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 { + t.Errorf("got = %+v, want %+v", got, in) + } + if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 { + t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1) + } + if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit { + t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit) + } + if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) { + t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts) + } +} + +func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + in := &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 1.0, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + // DailyLimitUSD nil → 无限额 + } + if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil { + t.Fatal(err) + } + got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if err != nil || !ok { + t.Fatalf("get: ok=%v err=%v", ok, err) + } + if got.DailyLimitUSD != nil { + t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD) + } +} + +func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) { + c, _ := newMiniRedisCache(t) + if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil { + t.Fatal(err) + } + _, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai") + if ok { + t.Error("expected key absent after no-op incr") + } +} + +func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + // SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。 + _ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{ + Version: 1, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + }, time.Minute) + if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil { + t.Fatal(err) + } + if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil { + t.Fatal(err) + } + got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 { + t.Errorf("got %+v, want daily/weekly/monthly=0.75", got) + } + if got.Version != 3 { + t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version) + } +} + +func TestUserPlatformQuotaCache_Delete(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + _ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute) + if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil { + t.Fatal(err) + } + _, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if ok { + t.Error("expected miss after delete") + } +} diff --git a/backend/internal/repository/content_moderation_repo.go b/backend/internal/repository/content_moderation_repo.go index 6ada004a..9b19cce9 100644 --- a/backend/internal/repository/content_moderation_repo.go +++ b/backend/internal/repository/content_moderation_repo.go @@ -192,6 +192,7 @@ SELECT COUNT(*) FROM content_moderation_logs WHERE user_id = $1 AND flagged = TRUE + AND action <> 'hash_block' AND created_at >= $2 AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz) `, userID, since).Scan(&count) @@ -246,7 +247,7 @@ func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) ( case "hit", "flagged": where = append(where, "l.flagged = TRUE") case "blocked", "block": - where = append(where, "l.action = 'block'") + where = append(where, "l.action IN ('block', 'keyword_block', 'hash_block')") case "pass", "allow": where = append(where, "l.flagged = FALSE AND l.error = ''") case "error": diff --git a/backend/internal/repository/content_moderation_repo_test.go b/backend/internal/repository/content_moderation_repo_test.go new file mode 100644 index 00000000..6d5faa12 --- /dev/null +++ b/backend/internal/repository/content_moderation_repo_test.go @@ -0,0 +1,40 @@ +package repository + +import ( + "context" + "regexp" + "strings" + "testing" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBuildContentModerationLogWhere_BlockedIncludesAllBlockActions(t *testing.T) { + where, args := buildContentModerationLogWhere(service.ContentModerationLogFilter{Result: "blocked"}) + + require.Empty(t, args) + sql := strings.Join(where, " AND ") + require.Contains(t, sql, "l.action IN ('block', 'keyword_block', 'hash_block')") + require.NotContains(t, sql, "l.action = 'block'") +} + +func TestContentModerationRepositoryCountFlaggedByUserSince_ExcludesHashBlock(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + repo := NewContentModerationRepository(db) + since := time.Now().Add(-time.Hour) + mock.ExpectQuery(regexp.QuoteMeta("AND action <> 'hash_block'")). + WithArgs(int64(1001), since). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(2)) + + count, err := repo.CountFlaggedByUserSince(context.Background(), 1001, since) + + require.NoError(t, err) + require.Equal(t, 2, count) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9c3b2010..ac8669ab 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -66,6 +66,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetModelsListConfig(groupIn.ModelsListConfig). SetRpmLimit(groupIn.RPMLimit) // 设置模型路由配置 @@ -141,6 +142,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetModelsListConfig(groupIn.ModelsListConfig). SetRpmLimit(groupIn.RPMLimit) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index a07de195..579add6c 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -3,6 +3,8 @@ package repository import ( "compress/flate" "compress/gzip" + "context" + "crypto/tls" "errors" "fmt" "io" @@ -50,6 +52,17 @@ const ( defaultMaxUpstreamClients = 5000 // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) defaultClientIdleTTLSeconds = 900 + // OpenAI HTTP/2 代理回退策略默认值 + defaultOpenAIHTTP2FallbackErrorThreshold = 2 + defaultOpenAIHTTP2FallbackWindow = 60 * time.Second + defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute +) + +const ( + upstreamProtocolModeDefault = "default" + upstreamProtocolModeOpenAIH1 = "openai_h1" + upstreamProtocolModeOpenAIH2 = "openai_h2" + upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback" ) var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") @@ -64,14 +77,30 @@ type poolSettings struct { responseHeaderTimeout time.Duration // 等待响应头超时时间 } +type openAIHTTP2Settings struct { + enabled bool + allowProxyFallbackToHTTP1 bool + fallbackErrorThreshold int + fallbackWindow time.Duration + fallbackTTL time.Duration +} + // upstreamClientEntry 上游客户端缓存条目 // 记录客户端实例及其元数据,用于连接池管理和淘汰策略 type upstreamClientEntry struct { - client *http.Client // HTTP 客户端实例 - proxyKey string // 代理标识(用于检测代理变更) - poolKey string // 连接池配置标识(用于检测配置变更) - lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 - inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 + client *http.Client // HTTP 客户端实例 + proxyKey string // 代理标识(用于检测代理变更) + poolKey string // 连接池配置标识(用于检测配置变更) + protocolMode string // 协议模式(default/openai_h1/openai_h2/openai_h1_fallback) + lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 + inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 +} + +type openAIHTTP2FallbackState struct { + mu sync.Mutex + windowStart time.Time + errorCount int + fallbackUntil time.Time } // httpUpstreamService 通用 HTTP 上游服务 @@ -95,6 +124,8 @@ type httpUpstreamService struct { cfg *config.Config // 全局配置 mu sync.RWMutex // 保护 clients map 的读写锁 clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 + // OpenAI 走 HTTP/HTTPS 代理时的 H2->H1 回退状态(key=标准化 proxyKey) + openAIHTTP2Fallbacks sync.Map } // NewHTTPUpstream 创建通用 HTTP 上游服务 @@ -142,9 +173,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i if err := s.validateRequestHost(req); err != nil { return nil, err } + profile := service.HTTPUpstreamProfileDefault + if req != nil { + profile = service.HTTPUpstreamProfileFromContext(req.Context()) + } // 获取或创建对应的客户端,并标记请求占用 - entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) + entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile) if err != nil { return nil, err } @@ -152,11 +187,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i // 执行请求 resp, err := entry.client.Do(req) if err != nil { + s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err) // 请求失败,立即减少计数 atomic.AddInt64(&entry.inFlight, -1) atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) return nil, err } + s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey) // 如果上游返回了压缩内容,解压后再交给业务层 decompressResponseBody(resp) @@ -179,6 +216,10 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco if profile == nil { return s.Do(req, proxyURL, accountID, accountConcurrency) } + upstreamProfile := service.HTTPUpstreamProfileDefault + if req != nil { + upstreamProfile = service.HTTPUpstreamProfileFromContext(req.Context()) + } targetHost := "" if req != nil && req.URL != nil { @@ -194,7 +235,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco return nil, err } - entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile) + entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile) if err != nil { slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err) return nil, err @@ -219,21 +260,23 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco } // acquireClientWithTLS 获取或创建带 TLS 指纹的客户端 -func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) { - return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true) +func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) { + return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile, true, true) } // getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目 // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 -func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { +func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) if err != nil { return nil, err } + settings := s.resolvePoolSettings(isolation, accountConcurrency) + settings = s.applyProfilePoolSettings(settings, upstreamProfile) // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 - cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" + cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault) + poolKey := buildPoolKey(settings, upstreamProtocolModeDefault) + ":tls" now := time.Now() nowUnix := now.UnixNano() @@ -284,7 +327,6 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i // 创建带 TLS 指纹的 Transport slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", logredact.RedactProxyURL(proxyKey)) - settings := s.resolvePoolSettings(isolation, accountConcurrency) transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile) if err != nil { s.mu.Unlock() @@ -350,7 +392,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req // acquireClient 获取或创建客户端,并标记为进行中请求 // 用于请求路径,避免在获取后被淘汰 func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) + return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault) +} + +// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。 +func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true) } // getOrCreateClient 获取或创建客户端 @@ -369,13 +416,13 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) + return s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false) } // getClientEntry 获取或创建客户端条目 // markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 // enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 -func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 @@ -383,10 +430,14 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a if err != nil { return nil, err } + // 根据请求 profile(例如 OpenAI)选择协议模式 + protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy) + settings := s.resolvePoolSettings(isolation, accountConcurrency) + settings = s.applyProfilePoolSettings(settings, profile) // 构建缓存键(根据隔离策略不同) - cacheKey := buildCacheKey(isolation, proxyKey, accountID) + cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode) // 构建连接池配置键(用于检测配置变更) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + poolKey := buildPoolKey(settings, protocolMode) now := time.Now() nowUnix := now.UnixNano() @@ -429,8 +480,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a } // 缓存未命中或需要重建,创建新客户端 - settings := s.resolvePoolSettings(isolation, accountConcurrency) - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode) if err != nil { s.mu.Unlock() return nil, fmt.Errorf("build transport: %w", err) @@ -440,9 +490,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a client.CheckRedirect = s.redirectChecker } entry := &upstreamClientEntry{ - client: client, - proxyKey: proxyKey, - poolKey: poolKey, + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + protocolMode: protocolMode, } atomic.StoreInt64(&entry.lastUsed, nowUnix) if markInFlight { @@ -626,22 +677,31 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu return settings } -// buildPoolKey 构建连接池配置键 -// 用于检测配置变更,配置变更时需要重建客户端 -// -// 参数: -// - isolation: 隔离模式 -// - accountConcurrency: 账户并发限制 -// -// 返回: -// - string: 配置键 -func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { - if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { - if accountConcurrency > 0 { - return fmt.Sprintf("account:%d", accountConcurrency) - } +func (s *httpUpstreamService) applyProfilePoolSettings(settings poolSettings, profile service.HTTPUpstreamProfile) poolSettings { + if profile != service.HTTPUpstreamProfileOpenAI { + return settings } - return "default" + settings.responseHeaderTimeout = 0 + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIResponseHeaderTimeout > 0 { + settings.responseHeaderTimeout = time.Duration(s.cfg.Gateway.OpenAIResponseHeaderTimeout) * time.Second + } + return settings +} + +// buildPoolKey 构建连接池配置键,用于检测连接池配置变更。 +func buildPoolKey(settings poolSettings, protocolMode string) string { + base := fmt.Sprintf( + "idle:%d|idle_host:%d|max:%d|idle_timeout:%s|header_timeout:%s", + settings.maxIdleConns, + settings.maxIdleConnsPerHost, + settings.maxConnsPerHost, + settings.idleConnTimeout, + settings.responseHeaderTimeout, + ) + if protocolMode == "" || protocolMode == upstreamProtocolModeDefault { + return base + } + return base + "|proto:" + protocolMode } // buildCacheKey 构建客户端缓存键 @@ -659,15 +719,245 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency // - proxy 模式: "proxy:{proxyKey}" // - account 模式: "account:{accountID}" // - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" -func buildCacheKey(isolation, proxyKey string, accountID int64) string { +func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string { + var base string switch isolation { case config.ConnectionPoolIsolationAccount: - return fmt.Sprintf("account:%d", accountID) + base = fmt.Sprintf("account:%d", accountID) case config.ConnectionPoolIsolationAccountProxy: - return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) + base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) default: - return fmt.Sprintf("proxy:%s", proxyKey) + base = fmt.Sprintf("proxy:%s", proxyKey) } + if protocolMode != "" && protocolMode != upstreamProtocolModeDefault { + base += "|proto:" + protocolMode + } + return base +} + +func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings { + settings := openAIHTTP2Settings{ + enabled: false, + allowProxyFallbackToHTTP1: true, + fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold, + fallbackWindow: defaultOpenAIHTTP2FallbackWindow, + fallbackTTL: defaultOpenAIHTTP2FallbackTTL, + } + if s == nil || s.cfg == nil { + return settings + } + cfg := s.cfg.Gateway.OpenAIHTTP2 + settings.enabled = cfg.Enabled + settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1 + if cfg.FallbackErrorThreshold > 0 { + settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold + } + if cfg.FallbackWindowSeconds > 0 { + settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second + } + if cfg.FallbackTTLSeconds > 0 { + settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second + } + return settings +} + +func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string { + if profile != service.HTTPUpstreamProfileOpenAI { + return upstreamProtocolModeDefault + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled { + return upstreamProtocolModeOpenAIH1 + } + if parsedProxy == nil { + return upstreamProtocolModeOpenAIH2 + } + scheme := strings.ToLower(parsedProxy.Scheme) + if scheme != "http" && scheme != "https" { + return upstreamProtocolModeOpenAIH2 + } + if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) { + return upstreamProtocolModeOpenAIH1Fallback + } + return upstreamProtocolModeOpenAIH2 +} + +func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool { + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if !ok { + return false + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return false + } + return state.isFallbackActive(time.Now()) +} + +func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState { + state := &openAIHTTP2FallbackState{} + actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state) + cached, ok := actual.(*openAIHTTP2FallbackState) + if !ok || cached == nil { + return state + } + return cached +} + +func isHTTPProxyKey(proxyKey string) bool { + return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://") +} + +func isOpenAIHTTP2CompatibilityError(err error) bool { + if err == nil { + return false + } + if isUpstreamTimeoutError(err) { + return false + } + msg := strings.ToLower(err.Error()) + if msg == "" { + return false + } + markers := []string{ + "alpn", + "no application protocol", + "protocol error", + "stream error", + "goaway", + "refused_stream", + "frame too large", + } + for _, marker := range markers { + if strings.Contains(msg, marker) { + return true + } + } + return false +} + +func isUpstreamTimeoutError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + msg := strings.ToLower(err.Error()) + if msg == "" { + return false + } + timeoutMarkers := []string{ + "timeout awaiting response headers", + "i/o timeout", + "context deadline exceeded", + "client.timeout exceeded while awaiting headers", + "tls handshake timeout", + } + for _, marker := range timeoutMarkers { + if strings.Contains(msg, marker) { + return true + } + } + return false +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled || !settings.allowProxyFallbackToHTTP1 { + return + } + if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) { + return + } + state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey) + activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL) + if activated { + slog.Warn("openai_http2_proxy_fallback_activated", + "proxy", proxyKey, + "fallback_until", until.Format(time.RFC3339)) + } +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + if !isHTTPProxyKey(proxyKey) { + return + } + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if !ok { + return + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return + } + state.resetErrorWindow() +} + +func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.fallbackUntil.IsZero() { + return false + } + if now.Before(s.fallbackUntil) { + return true + } + s.fallbackUntil = time.Time{} + return false +} + +func (s *openAIHTTP2FallbackState) resetErrorWindow() { + s.mu.Lock() + defer s.mu.Unlock() + s.windowStart = time.Time{} + s.errorCount = 0 +} + +func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) { + if threshold <= 0 { + threshold = defaultOpenAIHTTP2FallbackErrorThreshold + } + if window <= 0 { + window = defaultOpenAIHTTP2FallbackWindow + } + if ttl <= 0 { + ttl = defaultOpenAIHTTP2FallbackTTL + } + + s.mu.Lock() + defer s.mu.Unlock() + + if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) { + return false, s.fallbackUntil + } + if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) { + s.fallbackUntil = time.Time{} + } + + if s.windowStart.IsZero() || now.Sub(s.windowStart) > window { + s.windowStart = now + s.errorCount = 0 + } + s.errorCount++ + if s.errorCount < threshold { + return false, time.Time{} + } + + s.fallbackUntil = now.Add(ttl) + s.windowStart = time.Time{} + s.errorCount = 0 + return true, s.fallbackUntil } // normalizeProxyURL 标准化代理 URL @@ -739,7 +1029,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { if cfg.Gateway.IdleConnTimeoutSeconds > 0 { idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second } - if cfg.Gateway.ResponseHeaderTimeout > 0 { + if cfg.Gateway.ResponseHeaderTimeout >= 0 { responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second } } @@ -770,7 +1060,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) // - IdleConnTimeout: 空闲连接超时(超时后关闭) // - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) { +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) { transport := &http.Transport{ MaxIdleConns: settings.maxIdleConns, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, @@ -778,6 +1068,17 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra IdleConnTimeout: settings.idleConnTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout, } + switch protocolMode { + case upstreamProtocolModeOpenAIH2: + transport.ForceAttemptHTTP2 = true + case upstreamProtocolModeOpenAIH1: + transport.ForceAttemptHTTP2 = false + transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + case upstreamProtocolModeOpenAIH1Fallback: + // 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。 + transport.ForceAttemptHTTP2 = false + transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + } if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { return nil, err } diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 89892b3b..a92105c1 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -45,7 +45,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { // 每次迭代都创建新客户端,包含 Transport 分配 - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeDefault) if err != nil { b.Fatalf("创建 Transport 失败: %v", err) } diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 44eca078..63d2b786 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -1,6 +1,7 @@ package repository import ( + "errors" "io" "net/http" "sync/atomic" @@ -8,6 +9,8 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -41,12 +44,24 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { } // TestDefaultResponseHeaderTimeout 测试默认响应头超时配置 -// 验证未配置时使用 600 秒默认值(LLM 排队较久,本地从 300s 提至 600s) +// 验证显式 0 会禁用等待响应头超时 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), time.Duration(0), transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") +} + +// TestNilConfigResponseHeaderTimeoutFallback 验证 nil 配置使用代码级兜底值。 +// 本地分叉将默认值从 300s 提至 600s(LLM 排队较久)。 +func (s *HTTPUpstreamSuite) TestNilConfigResponseHeaderTimeoutFallback() { + up := NewHTTPUpstream(nil) + svc, ok := up.(*httpUpstreamService) + require.True(s.T(), ok, "expected *httpUpstreamService") + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 600*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } @@ -65,10 +80,130 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { // 验证解析失败时拒绝回退到直连模式 func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, service.HTTPUpstreamProfileDefault, false, false) require.Error(s.T(), err, "expected error for invalid proxy URL") } +func (s *HTTPUpstreamSuite) TestOpenAIProfileDefaultsToHTTP2AndNoHeaderTimeout() { + s.cfg.Gateway = config.GatewayConfig{ + ResponseHeaderTimeout: 600, + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + }, + } + svc := s.newService() + entry, err := svc.getClientEntry("", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), time.Duration(0), transport.ResponseHeaderTimeout, "OpenAI profile should not inherit generic header timeout") + require.True(s.T(), transport.ForceAttemptHTTP2, "OpenAI profile should prefer HTTP/2") + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, entry.protocolMode) +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfileCustomHeaderTimeout() { + s.cfg.Gateway = config.GatewayConfig{ + ResponseHeaderTimeout: 600, + OpenAIResponseHeaderTimeout: 1800, + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + }, + } + svc := s.newService() + entry, err := svc.getClientEntry("", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), 1800*time.Second, transport.ResponseHeaderTimeout) +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfileTLSFingerprintDoesNotInheritGenericHeaderTimeout() { + s.cfg.Gateway = config.GatewayConfig{ + ResponseHeaderTimeout: 600, + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + }, + } + svc := s.newService() + entry, err := svc.getClientEntryWithTLS("", 1, 1, &tlsfingerprint.Profile{Name: "test"}, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), time.Duration(0), transport.ResponseHeaderTimeout, "OpenAI TLS path should not inherit generic header timeout") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfileHTTP2DisabledUsesHTTP1Transport() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{Enabled: false}, + } + svc := s.newService() + entry, err := svc.getClientEntry("", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.False(s.T(), transport.ForceAttemptHTTP2, "OpenAI HTTP/2 disabled should not force H2") + require.NotNil(s.T(), transport.TLSNextProto, "HTTP/1 mode should disable automatic H2 negotiation") + require.Equal(s.T(), upstreamProtocolModeOpenAIH1, entry.protocolMode) +} + +func (s *HTTPUpstreamSuite) TestOpenAIHeaderTimeoutChangeRebuildsClient() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{Enabled: true}, + } + svc := s.newService() + entry1, err := svc.getClientEntry("", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + + s.cfg.Gateway.OpenAIResponseHeaderTimeout = 1800 + entry2, err := svc.getClientEntry("", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + require.NotSame(s.T(), entry1, entry2, "OpenAI header timeout changes must rebuild cached client") + transport, ok := entry2.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), 1800*time.Second, transport.ResponseHeaderTimeout) +} + +func (s *HTTPUpstreamSuite) TestOpenAIHTTP2TimeoutDoesNotActivateProxyFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 1, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("http2: timeout awaiting response headers")) + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "header timeout should not be treated as H2 compatibility failure") +} + +func (s *HTTPUpstreamSuite) TestOpenAIHTTP2ProxyCompatibilityErrorActivatesFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 1, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("http2: protocol error")) + require.True(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL)) + + entry, err := svc.getClientEntry(proxyURL, 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.False(s.T(), transport.ForceAttemptHTTP2) + require.NotNil(s.T(), transport.TLSNextProto) + require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, entry.protocolMode) +} + // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { diff --git a/backend/internal/repository/user_platform_quota_adapter_test.go b/backend/internal/repository/user_platform_quota_adapter_test.go new file mode 100644 index 00000000..a55d2e9c --- /dev/null +++ b/backend/internal/repository/user_platform_quota_adapter_test.go @@ -0,0 +1,91 @@ +//go:build unit + +package repository + +import ( + "context" + "errors" + "testing" + "time" +) + +type fakeRepoForAdapter struct { + upsertCalledWith []UserPlatformQuotaRecord + upsertCalledUserID int64 + upsertErr error + resetCalledWith [4]any // userID, platform, window, newStart + resetErr error +} + +func (f *fakeRepoForAdapter) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error { + return nil +} +func (f *fakeRepoForAdapter) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return nil, nil +} +func (f *fakeRepoForAdapter) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} +func (f *fakeRepoForAdapter) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} +func (f *fakeRepoForAdapter) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error { + f.resetCalledWith = [4]any{userID, platform, window, newStart} + return f.resetErr +} +func (f *fakeRepoForAdapter) UpsertForUser(_ context.Context, userID int64, records []UserPlatformQuotaRecord) error { + f.upsertCalledUserID = userID + f.upsertCalledWith = records + return f.upsertErr +} + +func TestGenericAdapter_UpsertForUser_ForwardsRecords(t *testing.T) { + fake := &fakeRepoForAdapter{} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.UpsertForUser(context.Background(), 42, nil) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if fake.upsertCalledUserID != 42 { + t.Errorf("forwarded userID = %d, want 42", fake.upsertCalledUserID) + } +} + +func TestGenericAdapter_UpsertForUser_PropagatesError(t *testing.T) { + wantErr := errors.New("boom") + fake := &fakeRepoForAdapter{upsertErr: wantErr} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.UpsertForUser(context.Background(), 1, nil) + if !errors.Is(err, wantErr) { + t.Errorf("expected %v, got %v", wantErr, err) + } +} + +func TestGenericAdapter_ResetExpiredWindow_ForwardsAllParams(t *testing.T) { + fake := &fakeRepoForAdapter{} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + now := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) + if err := adapter.ResetExpiredWindow(context.Background(), 7, "openai", "weekly", now); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if fake.resetCalledWith[0].(int64) != 7 || + fake.resetCalledWith[1].(string) != "openai" || + fake.resetCalledWith[2].(string) != "weekly" || + !fake.resetCalledWith[3].(time.Time).Equal(now) { + t.Errorf("forwarded params mismatch: %+v", fake.resetCalledWith) + } +} + +func TestGenericAdapter_ResetExpiredWindow_PropagatesError(t *testing.T) { + wantErr := errors.New("not found") + fake := &fakeRepoForAdapter{resetErr: wantErr} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.ResetExpiredWindow(context.Background(), 1, "a", "daily", time.Now()) + if !errors.Is(err, wantErr) { + t.Errorf("expected %v, got %v", wantErr, err) + } +} diff --git a/backend/internal/repository/user_platform_quota_repo.go b/backend/internal/repository/user_platform_quota_repo.go new file mode 100644 index 00000000..1e2e7f51 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo.go @@ -0,0 +1,416 @@ +package repository + +import ( + "context" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" +) + +// UserPlatformQuotaRecord 是 repository 层的传输结构体, +// 与 ent.UserPlatformQuota 实体解耦,供业务层使用。 +type UserPlatformQuotaRecord struct { + UserID int64 + Platform string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + +// ErrUserPlatformQuotaNotFound 用于 ResetExpiredWindow 等需要"必须命中已有记录"的方法。 +var ErrUserPlatformQuotaNotFound = fmt.Errorf("user platform quota record not found") + +// UserPlatformQuotaRepository 定义用户平台配额的数据访问接口。 +type UserPlatformQuotaRepository interface { + // BulkInsertInitial 幂等批量插入初始配额记录(ON CONFLICT DO NOTHING)。 + BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error + // GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。 + GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) + // ListByUser 查询用户的所有平台配额记录(排除软删除)。 + ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) + // IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。 + IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error + // ResetExpiredWindow 重置指定窗口(daily/weekly/monthly)的用量与起始时间。 + ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error + // UpsertForUser 全量替换该用户所有平台限额配置(详见 service.UserPlatformQuotaRepository.UpsertForUser)。 + UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error +} + +type userPlatformQuotaRepository struct { + client *dbent.Client +} + +// NewUserPlatformQuotaRepository 创建 UserPlatformQuotaRepository 实现。 +func NewUserPlatformQuotaRepository(client *dbent.Client) UserPlatformQuotaRepository { + return &userPlatformQuotaRepository{client: client} +} + +// BulkInsertInitial 用原生 SQL ON CONFLICT 实现幂等批量插入(带条件 limit 覆盖)。 +// 仅插入 limit_usd 与元数据,usage_usd 用 DB 默认 0,window_start 留 NULL。 +// FK 约束要求 user_id 在 users 表中存在,调用方负责保证。 +// +// 冲突策略:CASE WHEN existing.*_limit_usd IS NULL THEN EXCLUDED.*_limit_usd ELSE existing ... +// - 若 IncrementUsageWithReset 因时序问题已先建行(limit 全 NULL), +// 此处会把注册时的默认 limit 写入,避免该用户在该平台永久无限额。 +// - 若管理员已通过 UpsertForUser 设置了非 NULL 个性化 limit,**保留不动** +// —— 旧实现无条件 EXCLUDED 覆盖会丢失个性化配置。 +// - 不会改 usage_usd / window_start,保留累计的用量。 +// - 仅命中 deleted_at IS NULL 的活跃记录(partial unique index 作用域)。 +func (r *userPlatformQuotaRepository) BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error { + if len(records) == 0 { + return nil + } + + client := clientFromContext(ctx, r.client) + + var sb strings.Builder + _, _ = sb.WriteString("INSERT INTO user_platform_quotas (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) VALUES ") + args := make([]any, 0, len(records)*6) + // 统一时间戳:避免循环内多次 time.Now() 让同一批记录的 created_at/updated_at + // 出现亚毫秒级偏差(与 UpsertForUser 的 now := time.Now() 风格一致)。 + now := time.Now() + for i, rec := range records { + base := i * 6 + if i > 0 { + _, _ = sb.WriteString(",") + } + fmt.Fprintf(&sb, "($%d,$%d,$%d,$%d,$%d,0,0,0,$%d,$%d)", + base+1, base+2, base+3, base+4, base+5, base+6, base+6) + args = append(args, + rec.UserID, rec.Platform, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, + now, + ) + } + // 精确命中 partial unique index(deleted_at IS NULL),避免对软删记录的歧义冲突。 + // 条件覆盖:仅在现有 limit 为 NULL 时才写入 EXCLUDED,否则保留现有非 NULL 值。 + // - 修复 IncrementUsageWithReset 已用 NULL limit 建行的场景(NULL → 注册默认) + // - 保护管理员通过 UpsertForUser 设置的个性化 limit 不被静默覆盖 + _, _ = sb.WriteString(` ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL + DO UPDATE SET + daily_limit_usd = COALESCE(user_platform_quotas.daily_limit_usd, EXCLUDED.daily_limit_usd), + weekly_limit_usd = COALESCE(user_platform_quotas.weekly_limit_usd, EXCLUDED.weekly_limit_usd), + monthly_limit_usd = COALESCE(user_platform_quotas.monthly_limit_usd, EXCLUDED.monthly_limit_usd), + updated_at = EXCLUDED.updated_at`) + + _, err := client.ExecContext(ctx, sb.String(), args...) + return err +} + +// GetByUserPlatform 通过 ent 查询单条配额(排除软删除)。未找到返回 (nil, nil)。 +func (r *userPlatformQuotaRepository) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) { + client := clientFromContext(ctx, r.client) + entity, err := client.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ). + Only(ctx) + if dbent.IsNotFound(err) { + return nil, nil + } + if err != nil { + return nil, err + } + return entQuotaToRecord(entity), nil +} + +// ListByUser 查询用户的所有平台配额记录(排除软删除)。 +func (r *userPlatformQuotaRepository) ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) { + client := clientFromContext(ctx, r.client) + rows, err := client.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.DeletedAtIsNil(), + ). + All(ctx) + if err != nil { + return nil, err + } + out := make([]UserPlatformQuotaRecord, 0, len(rows)) + for _, e := range rows { + out = append(out, *entQuotaToRecord(e)) + } + return out, nil +} + +// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的 *_usage_usd。 +// 行为: +// - 若记录存在:在事务内 SELECT FOR UPDATE,按 (prev_window_start vs current_window_start) +// 判断是否需要重置(不同 = 重置为 cost;相同 = 累加 cost) +// - 若记录不存在(fail-open create 分支):插入新记录,**limit 字段保留 nil(无限制)** +// —— 这是预期行为:billing 链路不能因 quota 表缺失而阻断请求,未注册路径 +// 的用户 quota 默认放行,由调度层指标观测 + 后台对账补建 limit +// +// 上层正常路径(注册时 BulkInsertInitial)保证 limit 在记录创建时就被写入。 +func (r *userPlatformQuotaRepository) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + existing, err := txClient.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ). + ForUpdate(). + Only(txCtx) + if dbent.IsNotFound(err) { + // fail-open 建行:limit_* 保留 NULL(无限额)。 + // 用 ON CONFLICT DO UPDATE 累加,而非裸 INSERT:并发下另一请求可能在本事务 + // SELECT FOR UPDATE 之后、INSERT 之前刚建行,裸 INSERT 会撞 partial unique index + // 致事务回滚、本次 cost 丢失;DO UPDATE 把 cost 累加到既有 usage 上。 + // 写法与本文件 insertLimitsRow / BulkInsertInitial 的 ON CONFLICT 一致。 + const insertSQL = `INSERT INTO user_platform_quotas + (user_id, platform, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, + daily_window_start, weekly_window_start, monthly_window_start, created_at, updated_at) + VALUES ($1, $2, $3, $3, $3, $4, $5, $6, $7, $7) + ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO UPDATE SET + daily_usage_usd = user_platform_quotas.daily_usage_usd + EXCLUDED.daily_usage_usd, + weekly_usage_usd = user_platform_quotas.weekly_usage_usd + EXCLUDED.weekly_usage_usd, + monthly_usage_usd = user_platform_quotas.monthly_usage_usd + EXCLUDED.monthly_usage_usd, + updated_at = EXCLUDED.updated_at` + // $6 = now:30 天滚动月度窗口以当前时刻为起始 + _, e := txClient.ExecContext(txCtx, insertSQL, + userID, platform, cost, + timezone.StartOfDay(now), timezone.StartOfWeek(now), now, now) + return e + } + if err != nil { + return err + } + + newDaily := maybeReset(existing.DailyUsageUsd, existing.DailyWindowStart, timezone.StartOfDay(now), cost) + newWeekly := maybeReset(existing.WeeklyUsageUsd, existing.WeeklyWindowStart, timezone.StartOfWeek(now), cost) + // 30 天滚动月度窗口:过期时重置为 cost 并以 now 为新起始,否则累加保留原起始 + newMonthly, newMonthlyStart := monthlyMaybeReset(existing.MonthlyUsageUsd, existing.MonthlyWindowStart, cost, now) + + _, e := existing.Update(). + SetDailyUsageUsd(newDaily). + SetWeeklyUsageUsd(newWeekly). + SetMonthlyUsageUsd(newMonthly). + SetDailyWindowStart(timezone.StartOfDay(now)). + SetWeeklyWindowStart(timezone.StartOfWeek(now)). + SetMonthlyWindowStart(newMonthlyStart). // 30 天滚动:仅过期时更新起始 + Save(txCtx) + return e + }) +} + +// ResetExpiredWindow 无条件重置指定窗口(daily/weekly/monthly)的用量与起始时间。 +// +// ⚠️ 命名警告(NOT a "check-then-reset" helper): +// +// 名字里的 "Expired" 是历史遗留,**实现并不校验窗口是否真的过期**。 +// 任何调用都会无条件把对应窗口的 *_usage_usd 清零并重写 *_window_start。 +// 目前唯一合法 caller 是 admin POST /reset 接口(管理员强制归零)。 +// +// 如果你想要"仅在窗口过期才重置"的语义,请直接使用 IncrementUsageWithReset +// 的内部判断(maybeReset / monthlyMaybeReset),或新增独立函数; +// 不要复用这里的函数,否则会出现"明明窗口未过期,用量却被清零"的隐蔽 bug。 +// +// 未命中活跃记录时返回 ErrUserPlatformQuotaNotFound。 +func (r *userPlatformQuotaRepository) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + client := clientFromContext(ctx, r.client) + upd := client.UserPlatformQuota.Update(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ) + switch window { + case "daily": + upd = upd.SetDailyUsageUsd(0).SetDailyWindowStart(newStart) + case "weekly": + upd = upd.SetWeeklyUsageUsd(0).SetWeeklyWindowStart(newStart) + case "monthly": + upd = upd.SetMonthlyUsageUsd(0).SetMonthlyWindowStart(newStart) + default: + return fmt.Errorf("unknown window %q", window) + } + n, err := upd.Save(ctx) + if err != nil { + return err + } + if n == 0 { + return ErrUserPlatformQuotaNotFound + } + return nil +} + +// withTx 在事务中执行 fn,若 ctx 中已有事务则复用。 +func (r *userPlatformQuotaRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx, tx.Client()) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return fmt.Errorf("begin user_platform_quota transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx, tx.Client()); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit user_platform_quota transaction: %w", err) + } + return nil +} + +// entQuotaToRecord 将 ent entity 映射为 repository record。 +// 注意 ent 生成字段名为 DailyLimitUsd(非 DailyLimitUSD)。 +func entQuotaToRecord(e *dbent.UserPlatformQuota) *UserPlatformQuotaRecord { + return &UserPlatformQuotaRecord{ + UserID: e.UserID, + Platform: e.Platform, + DailyLimitUSD: e.DailyLimitUsd, + WeeklyLimitUSD: e.WeeklyLimitUsd, + MonthlyLimitUSD: e.MonthlyLimitUsd, + DailyUsageUSD: e.DailyUsageUsd, + WeeklyUsageUSD: e.WeeklyUsageUsd, + MonthlyUsageUSD: e.MonthlyUsageUsd, + DailyWindowStart: e.DailyWindowStart, + WeeklyWindowStart: e.WeeklyWindowStart, + MonthlyWindowStart: e.MonthlyWindowStart, + } +} + +// maybeReset 判断是否需要重置窗口用量: +// - 若 prevStart 为 nil 或与 currStart 不同,表示窗口已过期,返回 cost(重置) +// - 否则返回 prevUsage + cost(累加) +func maybeReset(prevUsage float64, prevStart *time.Time, currStart time.Time, cost float64) float64 { + if prevStart == nil || !prevStart.Equal(currStart) { + return cost + } + return prevUsage + cost +} + +// monthlyMaybeReset 判断 30 天滚动月度窗口是否需要重置。 +// 过期条件:prevStart 为 nil 或 now - prevStart >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。 +// 过期时重置为 cost,否则累加。返回 (newUsage, newWindowStart)。 +func monthlyMaybeReset(prevUsage float64, prevStart *time.Time, cost float64, now time.Time) (float64, time.Time) { + if prevStart == nil || now.Sub(*prevStart) >= 30*24*time.Hour { + return cost, now + } + return prevUsage + cost, *prevStart +} + +// UpsertForUser 全量替换该用户的所有平台限额(事务内): +// 1. 软删除未在 records 中出现的所有 active 行 +// 2. 对每条 record 尝试 UPDATE(含 deleted_at = NULL 兼容重激活); +// UPDATE 行数为 0 时 INSERT 新行 +// +// 仅改 *_limit_usd + deleted_at + updated_at,保留 *_usage_usd / *_window_start。 +func (r *userPlatformQuotaRepository) UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error { + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + platforms := make([]string, 0, len(records)) + for _, rec := range records { + platforms = append(platforms, rec.Platform) + } + now := time.Now() + if err := softDeleteMissingPlatforms(txCtx, txClient, userID, platforms, now); err != nil { + return err + } + for _, rec := range records { + affected, err := updateLimitsRow(txCtx, txClient, userID, rec, now) + if err != nil { + return err + } + if affected == 0 { + if err := insertLimitsRow(txCtx, txClient, userID, rec, now); err != nil { + return err + } + } + } + return nil + }) +} + +// softDeleteMissingPlatforms 软删除该用户所有不在 keepPlatforms 中的 active 行。 +// keepPlatforms 为空时 → 软删用户所有 active 行。 +// now 由调用方传入,与 updateLimitsRow / insertLimitsRow 共享同一个 Go time.Now(), +// 保证事务内所有时间戳一致(避免 Postgres NOW() 与 Go time.Now() 的微小偏差)。 +func softDeleteMissingPlatforms(ctx context.Context, client *dbent.Client, userID int64, keepPlatforms []string, now time.Time) error { + var ( + query string + args []any + ) + if len(keepPlatforms) == 0 { + query = `UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2 + WHERE user_id = $1 AND deleted_at IS NULL` + args = []any{userID, now} + } else { + placeholders := make([]string, len(keepPlatforms)) + args = make([]any, 0, len(keepPlatforms)+2) + args = append(args, userID, now) + for i, p := range keepPlatforms { + placeholders[i] = fmt.Sprintf("$%d", i+3) + args = append(args, p) + } + query = fmt.Sprintf(`UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2 + WHERE user_id = $1 AND deleted_at IS NULL AND platform NOT IN (%s)`, + strings.Join(placeholders, ",")) + } + _, err := client.ExecContext(ctx, query, args...) + return err +} + +// updateLimitsRow 尝试 UPDATE active 行(deleted_at IS NULL),返回受影响行数。 +// 仅更新 active 行:若存在多条历史软删记录,加 deleted_at IS NULL 守卫可避免 +// 批量重激活导致的 partial unique index(userplatformquota_user_id_platform_uq)冲突。 +// affected=0 时由调用方 UpsertForUser 走 insertLimitsRow 路径创建新行。 +func updateLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) (int64, error) { + const query = `UPDATE user_platform_quotas + SET daily_limit_usd = $1, weekly_limit_usd = $2, monthly_limit_usd = $3, + deleted_at = NULL, updated_at = $4 + WHERE user_id = $5 AND platform = $6 AND deleted_at IS NULL` + res, err := client.ExecContext(ctx, query, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, now, + userID, rec.Platform) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// insertLimitsRow 插入新限额行(usage 默认 0,window_start 默认 NULL)。 +// 带 ON CONFLICT ... DO NOTHING 守卫:防止两个并发请求同时为同一 user/platform 新建行时 +// 触发 unique constraint 违反(userplatformquota_user_id_platform_uq 部分唯一索引)。 +// affected=0 时说明另一个并发请求刚完成 INSERT,fallback 到 updateLimitsRow 覆写 limits 值。 +func insertLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) error { + const query = `INSERT INTO user_platform_quotas + (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, + daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, 0, 0, 0, $6, $6) + ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO NOTHING` + res, err := client.ExecContext(ctx, query, + userID, rec.Platform, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, + now) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + // 并发情形:另一请求已插入该行,fallback 到 UPDATE 覆写 limits 值(last-writer-wins)。 + _, err = updateLimitsRow(ctx, client, userID, rec, now) + return err + } + return nil +} diff --git a/backend/internal/repository/user_platform_quota_repo_integration_test.go b/backend/internal/repository/user_platform_quota_repo_integration_test.go new file mode 100644 index 00000000..f02eeaa9 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo_integration_test.go @@ -0,0 +1,269 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// mustCreateUserForQuota 在指定 client 上创建测试用户(满足 FK 约束)。 +func mustCreateUserForQuota(t *testing.T, client *dbent.Client) int64 { + t.Helper() + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("quota-test-%d@example.com", time.Now().UnixNano()), + }) + return u.ID +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_Idempotent(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + daily := 5.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily}, + {UserID: userID, Platform: "openai"}, + } + + // 第一次插入 + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "first insert") + // 第二次插入应为 no-op(ON CONFLICT DO NOTHING) + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "second insert (idempotent)") + + list, err := repo.ListByUser(txCtx, userID) + require.NoError(t, err, "list") + require.Len(t, list, 2, "expected 2 records after idempotent insert") + + // 校验 daily_limit_usd 保留 + var anthropicRec *UserPlatformQuotaRecord + for i := range list { + if list[i].Platform == "anthropic" { + anthropicRec = &list[i] + } + } + require.NotNil(t, anthropicRec, "anthropic record should exist") + require.NotNil(t, anthropicRec.DailyLimitUSD, "daily limit should be set") + require.InDelta(t, 5.0, *anthropicRec.DailyLimitUSD, 1e-9, "daily limit should be 5.0") +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_Empty(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewUserPlatformQuotaRepository(client) + // 空切片不应报错 + require.NoError(t, repo.BulkInsertInitial(txCtx, nil)) + require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{})) +} + +func TestUserPlatformQuotaRepository_GetByUserPlatform(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 未插入时应返回 nil + rec, err := repo.GetByUserPlatform(txCtx, userID, "anthropic") + require.NoError(t, err, "get before insert should not error") + require.Nil(t, rec, "get before insert should return nil") + + // 插入后查询 + daily := 10.0 + require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily}, + })) + + rec, err = repo.GetByUserPlatform(txCtx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.Equal(t, userID, rec.UserID) + require.Equal(t, "anthropic", rec.Platform) + require.NotNil(t, rec.DailyLimitUSD) + require.InDelta(t, 10.0, *rec.DailyLimitUSD, 1e-9) +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_SameWindow(t *testing.T) { + ctx := context.Background() + + // IncrementUsageWithReset 内部自己开事务,使用独立 ent client 确保跨事务可见 + client := testEntClient(t) + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五 + + // 首次调用:应新建记录 + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.5, now)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.InDelta(t, 1.5, rec.DailyUsageUSD, 1e-9, "initial daily usage") + require.InDelta(t, 1.5, rec.WeeklyUsageUSD, 1e-9, "initial weekly usage") + require.InDelta(t, 1.5, rec.MonthlyUsageUSD, 1e-9, "initial monthly usage") + + // 同一天再次调用:应累加 + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 0.5, now)) + + rec, err = repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "accumulated daily usage") + require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "accumulated weekly usage") + require.InDelta(t, 2.0, rec.MonthlyUsageUSD, 1e-9, "accumulated monthly usage") +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_DailyReset(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + day1 := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五(同一周、同一月) + day2 := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) // 周六(同一周、同一月) + + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.0, day1)) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.0, day2)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.InDelta(t, 1.0, rec.DailyUsageUSD, 1e-9, "daily should reset to 1.0") + require.InDelta(t, 4.0, rec.WeeklyUsageUSD, 1e-9, "weekly should accumulate to 4.0 (same week)") + require.InDelta(t, 4.0, rec.MonthlyUsageUSD, 1e-9, "monthly should accumulate to 4.0 (same month)") +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_WeeklyReset(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 5月22日(周五)和 5月25日(下周一),不同周 + fri := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) + nextMon := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) // 下一周周一 + + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 5.0, fri)) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 2.0, nextMon)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "openai") + require.NoError(t, err) + require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "daily resets to new cost") + require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "weekly resets (new week)") + require.InDelta(t, 7.0, rec.MonthlyUsageUSD, 1e-9, "monthly accumulates (same month)") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 先通过 ent 直接建一条记录 + _, err := client.UserPlatformQuota.Create(). + SetUserID(userID). + SetPlatform("gemini"). + SetDailyUsageUsd(10.0). + SetWeeklyUsageUsd(20.0). + SetMonthlyUsageUsd(50.0). + SetDailyWindowStart(time.Date(2026, 5, 21, 0, 0, 0, 0, time.UTC)). + SetWeeklyWindowStart(time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)). + SetMonthlyWindowStart(time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)). + Save(txCtx) + require.NoError(t, err) + + newStart := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + require.NoError(t, repo.ResetExpiredWindow(txCtx, userID, "gemini", "daily", newStart)) + + rec, err := repo.GetByUserPlatform(txCtx, userID, "gemini") + require.NoError(t, err) + require.InDelta(t, 0.0, rec.DailyUsageUSD, 1e-9, "daily usage reset to 0") + require.NotNil(t, rec.DailyWindowStart) + require.True(t, rec.DailyWindowStart.Equal(newStart), "daily window start updated") + // 其他窗口不变 + require.InDelta(t, 20.0, rec.WeeklyUsageUSD, 1e-9, "weekly usage unchanged") + require.InDelta(t, 50.0, rec.MonthlyUsageUSD, 1e-9, "monthly usage unchanged") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow_UnknownWindow(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + repo := NewUserPlatformQuotaRepository(client) + err := repo.ResetExpiredWindow(ctx, 999, "anthropic", "yearly", time.Now()) + require.Error(t, err, "unknown window should return error") +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_MultiRow(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d1, d2, d3 := 5.0, 10.0, 15.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d2}, + {UserID: userID, Platform: "gemini", DailyLimitUSD: &d3}, + } + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "multi-row insert failed") + + list, err := repo.ListByUser(txCtx, userID) + require.NoError(t, err) + require.Len(t, list, 3, "expected 3 rows, got %d", len(list)) + + // 验证 limit 值与传入一致(防占位符串位) + byPlatform := map[string]*UserPlatformQuotaRecord{} + for i := range list { + byPlatform[list[i].Platform] = &list[i] + } + require.NotNil(t, byPlatform["anthropic"], "anthropic record should exist") + require.NotNil(t, byPlatform["anthropic"].DailyLimitUSD, "anthropic daily limit should be set") + require.InDelta(t, 5.0, *byPlatform["anthropic"].DailyLimitUSD, 1e-9, "anthropic daily_limit = want 5.0") + + require.NotNil(t, byPlatform["openai"], "openai record should exist") + require.NotNil(t, byPlatform["openai"].DailyLimitUSD, "openai daily limit should be set") + require.InDelta(t, 10.0, *byPlatform["openai"].DailyLimitUSD, 1e-9, "openai daily_limit = want 10.0") + + require.NotNil(t, byPlatform["gemini"], "gemini record should exist") + require.NotNil(t, byPlatform["gemini"].DailyLimitUSD, "gemini daily limit should be set") + require.InDelta(t, 15.0, *byPlatform["gemini"].DailyLimitUSD, 1e-9, "gemini daily_limit = want 15.0") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow_NotFoundReturnsSentinel(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUserPlatformQuotaRepository(client) + + err := repo.ResetExpiredWindow(ctx, 99999, "anthropic", "daily", time.Now()) + require.True(t, errors.Is(err, ErrUserPlatformQuotaNotFound), + "expected ErrUserPlatformQuotaNotFound, got %v", err) +} diff --git a/backend/internal/repository/user_platform_quota_repo_test.go b/backend/internal/repository/user_platform_quota_repo_test.go new file mode 100644 index 00000000..84377353 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo_test.go @@ -0,0 +1,103 @@ +//go:build unit + +package repository + +import ( + "os" + "strings" + "testing" + "time" +) + +func TestMaybeReset(t *testing.T) { + start := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + other := start.AddDate(0, 0, -1) + cases := []struct { + name string + prevUsage float64 + prevStart *time.Time + currStart time.Time + cost float64 + want float64 + }{ + {"nil prev start resets", 10, nil, start, 1.5, 1.5}, + {"different start resets", 10, &other, start, 1.5, 1.5}, + {"same start accumulates", 10, &start, start, 1.5, 11.5}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := maybeReset(c.prevUsage, c.prevStart, c.currStart, c.cost); got != c.want { + t.Errorf("maybeReset = %v, want %v", got, c.want) + } + }) + } +} + +// TestMonthlyMaybeReset_NilStart 验证 prevStart=nil 时重置。 +func TestMonthlyMaybeReset_NilStart(t *testing.T) { + now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + usage, start := monthlyMaybeReset(10.0, nil, 1.5, now) + if usage != 1.5 { + t.Errorf("usage = %v, want 1.5", usage) + } + if !start.Equal(now) { + t.Errorf("start = %v, want %v", start, now) + } +} + +// TestMonthlyMaybeReset_Expired 验证窗口满 30 天时重置(30 天恰好到期)。 +func TestMonthlyMaybeReset_Expired(t *testing.T) { + windowStart := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC) + // now = windowStart + 30d(刚好到期) + now := windowStart.Add(30 * 24 * time.Hour) + usage, start := monthlyMaybeReset(8.0, &windowStart, 2.0, now) + if usage != 2.0 { + t.Errorf("usage = %v, want 2.0 (reset)", usage) + } + if !start.Equal(now) { + t.Errorf("start = %v, want %v (new window)", start, now) + } +} + +// TestMonthlyMaybeReset_CrossMonthBoundary 验证跨自然月时也使用 30 天滚动(不提前重置)。 +// 旧行为:5 月 1 日跨月立即重置;新行为:窗口起始 4 月 20 日,5 月 1 日仅过了 11 天,应累加。 +func TestMonthlyMaybeReset_CrossMonthBoundary(t *testing.T) { + windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + // 5 月 1 日:距起始 11 天,不足 30 天,应累加 + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + usage, start := monthlyMaybeReset(5.0, &windowStart, 1.0, now) + if usage != 6.0 { + t.Errorf("usage = %v, want 6.0 (accumulate, not reset at month boundary)", usage) + } + if !start.Equal(windowStart) { + t.Errorf("start = %v, want %v (preserved)", start, windowStart) + } +} + +// TestMonthlyMaybeReset_Active 验证窗口内正常累加。 +func TestMonthlyMaybeReset_Active(t *testing.T) { + windowStart := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + // 15 天内,窗口有效 + now := windowStart.Add(15 * 24 * time.Hour) + usage, start := monthlyMaybeReset(3.0, &windowStart, 0.5, now) + if usage != 3.5 { + t.Errorf("usage = %v, want 3.5", usage) + } + if !start.Equal(windowStart) { + t.Errorf("start = %v, want %v", start, windowStart) + } +} + +// TestUpdateLimitsRowQuery_HasDeletedAtGuard 通过读取源文件验证 updateLimitsRow +// 的 SQL WHERE 子句包含 deleted_at IS NULL 守卫(I-NEW-1)。 +// 此防回归测试可在无 DB 的 CI 环境中运行,防止意外删除该守卫。 +func TestUpdateLimitsRowQuery_HasDeletedAtGuard(t *testing.T) { + src, err := os.ReadFile("user_platform_quota_repo.go") + if err != nil { + t.Fatalf("failed to read source file: %v", err) + } + const guard = "AND deleted_at IS NULL" + if !strings.Contains(string(src), guard) { + t.Errorf("updateLimitsRow SQL must contain %q to prevent bulk reactivation of soft-deleted rows (I-NEW-1)", guard) + } +} diff --git a/backend/internal/repository/user_platform_quota_service_adapter.go b/backend/internal/repository/user_platform_quota_service_adapter.go new file mode 100644 index 00000000..7495cd26 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_service_adapter.go @@ -0,0 +1,206 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// userPlatformQuotaServiceAdapter 将 repository 层的 userPlatformQuotaRepository +// 适配为 service.UserPlatformQuotaRepository 接口(返回 *service.UserPlatformQuotaRecord)。 +type userPlatformQuotaServiceAdapter struct { + inner *userPlatformQuotaRepository +} + +// NewUserPlatformQuotaServiceAdapter 将 UserPlatformQuotaRepository 实现包装为 +// 满足 service.UserPlatformQuotaRepository 接口的适配器。 +func NewUserPlatformQuotaServiceAdapter(repo UserPlatformQuotaRepository) service.UserPlatformQuotaRepository { + impl, ok := repo.(*userPlatformQuotaRepository) + if !ok { + // 非标准实现(如测试 fake),通过通用适配器包装 + return &genericUserPlatformQuotaAdapter{inner: repo} + } + return &userPlatformQuotaServiceAdapter{inner: impl} +} + +func (a *userPlatformQuotaServiceAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) { + rec, err := a.inner.GetByUserPlatform(ctx, userID, platform) + if err != nil || rec == nil { + return nil, err + } + return toServiceRecord(rec), nil +} + +// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的用量。 +func (a *userPlatformQuotaServiceAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now) +} + +// ListByUser 查询用户的所有平台配额记录。 +func (a *userPlatformQuotaServiceAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) { + rows, err := a.inner.ListByUser(ctx, userID) + if err != nil { + return nil, err + } + out := make([]service.UserPlatformQuotaRecord, len(rows)) + for i, r := range rows { + out[i] = service.UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out, nil +} + +// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 repo。 +func (a *userPlatformQuotaServiceAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error { + repoRecords := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + repoRecords[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + } + } + return a.inner.BulkInsertInitial(ctx, repoRecords) +} + +// UpsertForUser 全量替换该用户所有平台限额。 +func (a *userPlatformQuotaServiceAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + repoRecords := toRepoRecords(records) + return a.inner.UpsertForUser(ctx, userID, repoRecords) +} + +// ResetExpiredWindow 转发至 repository.ResetExpiredWindow,并将 repository sentinel 包装为 service sentinel。 +func (a *userPlatformQuotaServiceAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart) + if errors.Is(err, ErrUserPlatformQuotaNotFound) { + return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err) + } + return err +} + +// genericUserPlatformQuotaAdapter 通过通用接口适配(用于测试 fake 或非标准实现)。 +type genericUserPlatformQuotaAdapter struct { + inner UserPlatformQuotaRepository +} + +func (a *genericUserPlatformQuotaAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) { + rec, err := a.inner.GetByUserPlatform(ctx, userID, platform) + if err != nil || rec == nil { + return nil, err + } + return toServiceRecord(rec), nil +} + +// IncrementUsageWithReset 原子累加 cost(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now) +} + +// ListByUser 查询用户的所有平台配额记录(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) { + rows, err := a.inner.ListByUser(ctx, userID) + if err != nil { + return nil, err + } + out := make([]service.UserPlatformQuotaRecord, len(rows)) + for i, r := range rows { + out[i] = service.UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out, nil +} + +// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 generic repo。 +func (a *genericUserPlatformQuotaAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error { + repoRecords := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + repoRecords[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + } + } + return a.inner.BulkInsertInitial(ctx, repoRecords) +} + +// UpsertForUser 全量替换(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + repoRecords := toRepoRecords(records) + return a.inner.UpsertForUser(ctx, userID, repoRecords) +} + +// ResetExpiredWindow 转发至 repository.ResetExpiredWindow(通用 adapter),并包装 sentinel。 +func (a *genericUserPlatformQuotaAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart) + if errors.Is(err, ErrUserPlatformQuotaNotFound) { + return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err) + } + return err +} + +// toServiceRecord 将 repository.UserPlatformQuotaRecord 转换为 service.UserPlatformQuotaRecord。 +func toServiceRecord(rec *UserPlatformQuotaRecord) *service.UserPlatformQuotaRecord { + return &service.UserPlatformQuotaRecord{ + UserID: rec.UserID, + Platform: rec.Platform, + DailyLimitUSD: rec.DailyLimitUSD, + WeeklyLimitUSD: rec.WeeklyLimitUSD, + MonthlyLimitUSD: rec.MonthlyLimitUSD, + DailyUsageUSD: rec.DailyUsageUSD, + WeeklyUsageUSD: rec.WeeklyUsageUSD, + MonthlyUsageUSD: rec.MonthlyUsageUSD, + DailyWindowStart: rec.DailyWindowStart, + WeeklyWindowStart: rec.WeeklyWindowStart, + MonthlyWindowStart: rec.MonthlyWindowStart, + } +} + +// toRepoRecords 将 service.UserPlatformQuotaRecord 切片转换为 repository.UserPlatformQuotaRecord(含 limit 字段,含 usage/window_start)。 +func toRepoRecords(records []service.UserPlatformQuotaRecord) []UserPlatformQuotaRecord { + out := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + out[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out +} diff --git a/backend/internal/repository/user_platform_quota_upsert_test.go b/backend/internal/repository/user_platform_quota_upsert_test.go new file mode 100644 index 00000000..db8e6f47 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_upsert_test.go @@ -0,0 +1,148 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" + "github.com/stretchr/testify/require" +) + +func TestUpsertForUser_NewUserInsertsAllRecords(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + daily := 10.0 + weekly := 50.0 + monthly := 200.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily, WeeklyLimitUSD: &weekly, MonthlyLimitUSD: &monthly}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &daily}, + } + require.NoError(t, repo.UpsertForUser(ctx, userID, records)) + + got, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + require.Len(t, got, 2) +} + +func TestUpsertForUser_PartialUpdateSoftDeletesMissingPlatforms(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d1 := 10.0 + d2 := 20.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d1}, + })) + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2}, + {UserID: userID, Platform: "gemini", DailyLimitUSD: &d1}, + })) + + active, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + platforms := map[string]float64{} + for _, r := range active { + require.NotNil(t, r.DailyLimitUSD) + platforms[r.Platform] = *r.DailyLimitUSD + } + require.Len(t, platforms, 2) + require.InDelta(t, 20.0, platforms["anthropic"], 1e-9) + require.InDelta(t, 10.0, platforms["gemini"], 1e-9) + _, openaiActive := platforms["openai"] + require.False(t, openaiActive, "openai should be soft-deleted") +} + +func TestUpsertForUser_PreservesUsageAndWindowStart(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + })) + + now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.5, now)) + + newD := 50.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &newD}, + })) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.InDelta(t, 50.0, *rec.DailyLimitUSD, 1e-9, "limit should update") + require.InDelta(t, 3.5, rec.DailyUsageUSD, 1e-9, "usage must be preserved") + require.NotNil(t, rec.DailyWindowStart, "window_start must be preserved") +} + +func TestUpsertForUser_ReactivatesSoftDeleted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + })) + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{})) + + gone, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.Nil(t, gone, "anthropic should be soft-deleted (not active)") + + d2 := 20.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2}, + })) + + back, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, back, "anthropic should be active again") + require.InDelta(t, 20.0, *back.DailyLimitUSD, 1e-9) + + allRows, err := client.UserPlatformQuota.Query(). + Where(userplatformquota.UserIDEQ(userID), userplatformquota.PlatformEQ("anthropic")). + All(ctx) + require.NoError(t, err) + activeCount := 0 + for _, r := range allRows { + if r.DeletedAt == nil { + activeCount++ + } + } + require.Equal(t, 1, activeCount, "should have exactly one active row") +} + +func TestUpsertForUser_EmptyClearsAll(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d}, + })) + + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{})) + + got, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + require.Empty(t, got) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 3c0ee9cb..2d1b04e3 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -93,6 +93,8 @@ var ProviderSet = wire.NewSet( NewChannelMonitorRequestTemplateRepository, NewContentModerationRepository, NewAffiliateRepository, + NewUserPlatformQuotaRepository, // T14: user × platform quota + NewUserPlatformQuotaServiceAdapter, // T14: adapter → service.UserPlatformQuotaRepository // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 8dd9e6fb..6c23195b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -798,6 +798,14 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}}, + "auth_source_default_email_platform_quotas": null, + "auth_source_default_github_platform_quotas": null, + "auth_source_default_google_platform_quotas": null, + "auth_source_default_linuxdo_platform_quotas": null, + "auth_source_default_oidc_platform_quotas": null, + "auth_source_default_wechat_platform_quotas": null, + "auth_source_default_dingtalk_platform_quotas": null, "affiliate_rebate_rate": 20, "affiliate_rebate_freeze_hours": 0, "affiliate_rebate_duration_days": 0, @@ -1025,6 +1033,14 @@ func TestAPIContracts(t *testing.T) { "purchase_subscription_url": "", "table_default_page_size": 20, "table_page_size_options": [10, 20, 50], + "default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}}, + "auth_source_default_email_platform_quotas": null, + "auth_source_default_github_platform_quotas": null, + "auth_source_default_google_platform_quotas": null, + "auth_source_default_linuxdo_platform_quotas": null, + "auth_source_default_oidc_platform_quotas": null, + "auth_source_default_wechat_platform_quotas": null, + "auth_source_default_dingtalk_platform_quotas": null, "custom_menu_items": [], "custom_endpoints": [], "default_concurrency": 0, @@ -1718,7 +1734,7 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt return errors.New("not implemented") } -func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { +func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { return errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 3fbbb716..303d0db8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 7b9a1ee0..d33ccbf5 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -263,6 +263,7 @@ func abortIfAPIKeyGroupUnavailable(c *gin.Context, apiKey *service.APIKey) bool if ok { return false } + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnavailable) AbortWithError(c, 403, code, message) return true } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 3ed71f71..596bed52 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -55,6 +55,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs return } if _, message, ok := validateAPIKeyGroupAvailable(apiKey); !ok { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnavailable) abortWithGoogleError(c, 403, message) return } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index f8e50fcd..feadd27d 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -373,6 +373,68 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) } +func TestApiKeyAuthWithSubscriptionGoogle_MarksUnavailableGroupBusinessLimited(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(101) + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + GroupID: &groupID, + Key: "google-group-deleted", + Status: service.StatusActive, + User: user, + Group: &service.Group{ + ID: groupID, + Name: "deleted", + Status: "deleted", + Platform: service.PlatformGemini, + Hydrated: true, + }, + } + + r := gin.New() + var markedBusinessLimited bool + var businessLimitedReason string + r.Use(func(c *gin.Context) { + c.Next() + markedBusinessLimited = service.HasOpsClientBusinessLimited(c) + if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok { + businessLimitedReason, _ = v.(string) + } + }) + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }) + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{RunMode: config.RunModeSimple})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusForbidden, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "API Key 所属分组已删除", resp.Error.Message) + require.True(t, markedBusinessLimited) + require.Equal(t, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnavailable, businessLimitedReason) +} + func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 57e69f10..76a24192 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -317,6 +317,7 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { group *service.Group wantStatus int wantCode string + wantMarked bool }{ { name: "active group passes", @@ -340,6 +341,7 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { }, wantStatus: http.StatusForbidden, wantCode: "GROUP_DISABLED", + wantMarked: true, }, { name: "deleted status group is forbidden", @@ -352,12 +354,14 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { }, wantStatus: http.StatusForbidden, wantCode: "GROUP_DELETED", + wantMarked: true, }, { name: "missing group edge is forbidden", group: nil, wantStatus: http.StatusForbidden, wantCode: "GROUP_DELETED", + wantMarked: true, }, } @@ -383,7 +387,20 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeStandard} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - router := newAuthTestRouter(apiKeyService, nil, cfg) + router := gin.New() + var markedBusinessLimited bool + var businessLimitedReason string + router.Use(func(c *gin.Context) { + c.Next() + markedBusinessLimited = service.HasOpsClientBusinessLimited(c) + if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok { + businessLimitedReason, _ = v.(string) + } + }) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/t", nil) @@ -394,10 +411,57 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { if tt.wantCode != "" { require.Contains(t, w.Body.String(), tt.wantCode) } + require.Equal(t, tt.wantMarked, markedBusinessLimited) + if tt.wantMarked { + require.Equal(t, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnavailable, businessLimitedReason) + } }) } } +func TestRequireGroupAssignmentMarksUngroupedKeyBusinessLimited(t *testing.T) { + gin.SetMode(gin.TestMode) + + settingService := service.NewSettingService(fakeSettingRepo{ + values: map[string]string{ + service.SettingKeyAllowUngroupedKeyScheduling: "false", + }, + }, &config.Config{}) + apiKey := &service.APIKey{ + ID: 100, + Key: "ungrouped-key", + Status: service.StatusActive, + } + + router := gin.New() + var markedBusinessLimited bool + var businessLimitedReason string + router.Use(func(c *gin.Context) { + c.Next() + markedBusinessLimited = service.HasOpsClientBusinessLimited(c) + if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok { + businessLimitedReason, _ = v.(string) + } + }) + router.Use(func(c *gin.Context) { + c.Set(string(ContextKeyAPIKey), apiKey) + c.Next() + }) + router.Use(RequireGroupAssignment(settingService, AnthropicErrorWriter)) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "not assigned to any group") + require.True(t, markedBusinessLimited) + require.Equal(t, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnassigned, businessLimitedReason) +} + func TestAPIKeyAuthIPRestrictionDoesNotTrustForwardedClientIPByDefault(t *testing.T) { gin.SetMode(gin.TestMode) @@ -771,6 +835,41 @@ type stubUserSubscriptionRepo struct { resetMonthly func(ctx context.Context, id int64, start time.Time) error } +type fakeSettingRepo struct { + values map[string]string +} + +func (r fakeSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + return nil, errors.New("not implemented") +} + +func (r fakeSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (r fakeSettingRepo) Set(ctx context.Context, key, value string) error { + return errors.New("not implemented") +} + +func (r fakeSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + return nil, errors.New("not implemented") +} + +func (r fakeSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + return errors.New("not implemented") +} + +func (r fakeSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + return nil, errors.New("not implemented") +} + +func (r fakeSettingRepo) Delete(ctx context.Context, key string) error { + return errors.New("not implemented") +} + func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") } diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index a643d3bc..e9922e54 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) @@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) toucher := &recordingActivityToucher{} diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index 27985cf8..d42eacec 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -115,6 +115,7 @@ func RequireGroupAssignment(settingService *service.SettingService, writeError G c.Next() return } + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonAPIKeyGroupUnassigned) writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.") c.Abort() } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 738ae2c4..043faf5e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -244,6 +244,9 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency) + users.GET("/:id/platform-quotas", h.Admin.User.GetUserPlatformQuotas) + users.PUT("/:id/platform-quotas", h.Admin.User.UpdateUserPlatformQuotas) + users.POST("/:id/platform-quotas/reset", h.Admin.User.ResetUserPlatformQuotaWindow) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -259,6 +262,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary) groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) + groups.GET("/:id/models-list-candidates", h.Admin.Group.GetModelsListCandidates) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) groups.PUT("/:id", h.Admin.Group.Update) @@ -288,6 +292,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/test", h.Admin.Account.Test) accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState) accounts.POST("/:id/refresh", h.Admin.Account.Refresh) + accounts.POST("/:id/apply-oauth-credentials", h.Admin.Account.ApplyOAuthCredentials) accounts.POST("/:id/set-privacy", h.Admin.Account.SetPrivacy) accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier) accounts.GET("/:id/stats", h.Admin.Account.GetStats) diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 9541cda1..b039a6ec 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -51,6 +51,7 @@ func RegisterGatewayRoutes( // /v1/messages/count_tokens: OpenAI groups get 404 gateway.POST("/messages/count_tokens", func(c *gin.Context) { if getGroupPlatform(c) == service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusNotFound, gin.H{ "type": "error", "error": gin.H{ @@ -88,8 +89,22 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + gateway.POST("/embeddings", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Embeddings API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Embeddings(c) + }) gateway.POST("/images/generations", func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ "type": "not_found_error", @@ -102,6 +117,7 @@ func RegisterGatewayRoutes( }) gateway.POST("/images/edits", func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ "type": "not_found_error", @@ -155,8 +171,22 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + r.POST("/embeddings", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Embeddings API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Embeddings(c) + }) r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ "type": "not_found_error", @@ -169,6 +199,7 @@ func RegisterGatewayRoutes( }) r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { if getGroupPlatform(c) != service.PlatformOpenAI { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusNotFound, gin.H{ "error": gin.H{ "type": "not_found_error", diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index e79d3ee3..07ae33de 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -32,6 +32,7 @@ func RegisterUserRoutes( user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage) + user.GET("/platform-quotas", h.User.GetMyPlatformQuotas) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 98fca03d..91d7eeb6 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -910,14 +910,90 @@ func parsePoolModeRetryCount(value any) int { return defaultPoolModeRetryCount } -// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码 +// defaultPoolModeRetryableStatusCodes 池模式下默认触发同账号重试的状态码。 +// 未在 Account.Credentials 中显式配置 pool_mode_retry_status_codes 时使用。 +var defaultPoolModeRetryableStatusCodes = []int{401, 403, 429} + +// isPoolModeRetryableStatus 池模式下应触发同账号重试的状态码(默认列表)。 func isPoolModeRetryableStatus(statusCode int) bool { - switch statusCode { - case 401, 403, 429: - return true - default: - return false + for _, c := range defaultPoolModeRetryableStatusCodes { + if c == statusCode { + return true + } } + return false +} + +// GetPoolModeRetryStatusCodes 返回账号自定义的池模式同账号重试状态码列表。 +// +// 返回值语义: +// - nil:未配置 → 调用方应回退到默认值 [401, 403, 429] +// - 长度为 0 的切片:管理员显式置空 → 关闭按状态码触发的同账号重试 +// - 非空切片:去重、过滤为合法 HTTP 状态码(100-599)后的覆盖列表 +func (a *Account) GetPoolModeRetryStatusCodes() []int { + if a == nil || a.Credentials == nil { + return nil + } + raw, ok := a.Credentials["pool_mode_retry_status_codes"] + if !ok || raw == nil { + return nil + } + arr, ok := raw.([]any) + if !ok { + return nil + } + seen := make(map[int]struct{}, len(arr)) + codes := make([]int, 0, len(arr)) + for _, v := range arr { + var code int + switch n := v.(type) { + case float64: + code = int(n) + case int: + code = n + case int64: + code = int(n) + case json.Number: + i, err := n.Int64() + if err != nil { + continue + } + code = int(i) + case string: + i, err := strconv.Atoi(strings.TrimSpace(n)) + if err != nil { + continue + } + code = i + default: + continue + } + if code < 100 || code > 599 { + continue + } + if _, exists := seen[code]; exists { + continue + } + seen[code] = struct{}{} + codes = append(codes, code) + } + sort.Ints(codes) + return codes +} + +// IsPoolModeRetryableStatus 在账号上下文中判断给定状态码是否应触发同账号重试。 +// 若账号未配置 pool_mode_retry_status_codes,则回退到默认列表。 +func (a *Account) IsPoolModeRetryableStatus(statusCode int) bool { + codes := a.GetPoolModeRetryStatusCodes() + if codes == nil { + return isPoolModeRetryableStatus(statusCode) + } + for _, c := range codes { + if c == statusCode { + return true + } + } + return false } func (a *Account) GetCustomErrorCodes() []int { diff --git a/backend/internal/service/account_pool_retry_status_codes_test.go b/backend/internal/service/account_pool_retry_status_codes_test.go new file mode 100644 index 00000000..c0b9d7ab --- /dev/null +++ b/backend/internal/service/account_pool_retry_status_codes_test.go @@ -0,0 +1,193 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetPoolModeRetryStatusCodes(t *testing.T) { + tests := []struct { + name string + account *Account + expected []int + }{ + { + name: "nil_account_returns_nil", + account: nil, + expected: nil, + }, + { + name: "nil_credentials_returns_nil", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + }, + expected: nil, + }, + { + name: "missing_key_returns_nil", + account: &Account{ + Type: AccountTypeAPIKey, + Platform: PlatformOpenAI, + Credentials: map[string]any{"pool_mode": true}, + }, + expected: nil, + }, + { + name: "empty_slice_is_preserved", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{}, + }, + }, + expected: []int{}, + }, + { + name: "float64_values_from_json_are_normalized", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{float64(429), float64(401), float64(403)}, + }, + }, + expected: []int{401, 403, 429}, + }, + { + name: "json_number_values_supported", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{json.Number("502"), json.Number("503")}, + }, + }, + expected: []int{502, 503}, + }, + { + name: "string_values_supported", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{"520", "529"}, + }, + }, + expected: []int{520, 529}, + }, + { + name: "duplicates_are_deduped", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{float64(429), float64(429), float64(401)}, + }, + }, + expected: []int{401, 429}, + }, + { + name: "out_of_range_values_dropped", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{float64(99), float64(600), float64(429)}, + }, + }, + expected: []int{429}, + }, + { + name: "invalid_string_dropped", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{"oops", float64(429)}, + }, + }, + expected: []int{429}, + }, + { + name: "non_array_value_returns_nil", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": "not-an-array", + }, + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.account.GetPoolModeRetryStatusCodes()) + }) + } +} + +func TestIsPoolModeRetryableStatus_Account(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + expected bool + }{ + { + name: "nil_account_falls_back_to_default_401", + account: nil, + statusCode: 401, + expected: true, + }, + { + name: "nil_account_falls_back_to_default_500", + account: nil, + statusCode: 500, + expected: false, + }, + { + name: "unconfigured_uses_default_403", + account: &Account{ + Credentials: map[string]any{"pool_mode": true}, + }, + statusCode: 403, + expected: true, + }, + { + name: "unconfigured_uses_default_502_false", + account: &Account{ + Credentials: map[string]any{"pool_mode": true}, + }, + statusCode: 502, + expected: false, + }, + { + name: "configured_list_overrides_default_401_dropped", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{float64(502), float64(503)}, + }, + }, + statusCode: 401, + expected: false, + }, + { + name: "configured_list_overrides_default_502_added", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{float64(502), float64(503)}, + }, + }, + statusCode: 502, + expected: true, + }, + { + name: "empty_list_disables_all_default_codes", + account: &Account{ + Credentials: map[string]any{ + "pool_mode_retry_status_codes": []any{}, + }, + }, + statusCode: 429, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.account.IsPoolModeRetryableStatus(tt.statusCode)) + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 64f3710b..571772e9 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -61,7 +61,7 @@ type AccountRepository interface { ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error - SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error + SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error ClearTempUnschedulable(ctx context.Context, id int64) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a1537252..c20a8c8c 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -162,7 +162,7 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt panic("unexpected SetRateLimited call") } -func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { +func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { panic("unexpected SetModelRateLimit call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 397004ac..b79c794f 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -597,6 +597,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, "Failed to create request") } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // Set common headers req.Header.Set("Content-Type", "application/json") @@ -676,6 +677,7 @@ func (s *AccountTestService) testOpenAIChatCompletionsConnection( if err != nil { return s.sendErrorAndEnd(c, "Failed to create Chat Completions request") } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") req.Header.Set("Authorization", "Bearer "+authToken) @@ -756,6 +758,7 @@ func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, "Failed to create request") } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -1564,6 +1567,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C if err != nil { return s.sendErrorAndEnd(c, "Failed to create request") } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+authToken) @@ -1652,6 +1656,7 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co if err != nil { return s.sendErrorAndEnd(c, "Failed to create request") } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Host = "chatgpt.com" req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("Content-Type", "application/json") diff --git a/backend/internal/service/account_test_service_openai_compact_test.go b/backend/internal/service/account_test_service_openai_compact_test.go index 9eb98fdc..c9849e04 100644 --- a/backend/internal/service/account_test_service_openai_compact_test.go +++ b/backend/internal/service/account_test_service_openai_compact_test.go @@ -57,6 +57,7 @@ func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersi require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept")) require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version")) require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id")) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent")) require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go index 257159c4..9c24070c 100644 --- a/backend/internal/service/account_test_service_openai_image_test.go +++ b/backend/internal/service/account_test_service_openai_image_test.go @@ -45,6 +45,8 @@ func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *tes err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat") require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool") require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") require.Contains(t, rec.Body.String(), "\"success\":true") @@ -83,6 +85,7 @@ func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing. err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat") require.NoError(t, err) require.NotNil(t, upstream.lastReq) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String()) require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization")) require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 9844957a..910567fb 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -129,6 +129,8 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") require.NoError(t, err) + require.Len(t, upstream.requests, 1) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.requests[0].Context())) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"]) @@ -372,6 +374,7 @@ func TestAccountTestService_OpenAIAPIKeyResponsesUnsupportedUsesChatCompletionsP err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "hello", "") require.NoError(t, err) require.NotNil(t, upstream.lastReq) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization")) require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index cdc5217e..5352df6f 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -17,9 +17,13 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" ) @@ -48,6 +52,7 @@ type AdminService interface { GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) + GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error @@ -72,6 +77,9 @@ type AdminService interface { GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) + // UpdateAccountExtra 仅对 Extra 做 JSONB 增量合并(key 级覆盖),不会影响其它字段或运行态键。 + // 用于刷新流程持久化 account_uuid / org_uuid 等少量键,避免被全量快照覆盖。 + UpdateAccountExtra(ctx context.Context, id int64, updates map[string]any) error DeleteAccount(ctx context.Context, id int64) error RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error) @@ -212,6 +220,7 @@ type CreateGroupInput struct { RequireOAuthOnly bool RequirePrivacySet bool MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + ModelsListConfig GroupModelsListConfig // RPMLimit 分组 RPM 上限(0 = 不限制) RPMLimit int // 从指定分组复制账号(创建分组后在同一事务内绑定) @@ -252,6 +261,7 @@ type UpdateGroupInput struct { RequireOAuthOnly *bool RequirePrivacySet *bool MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + ModelsListConfig *GroupModelsListConfig // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 RPMLimit *int // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) @@ -1579,6 +1589,80 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro return s.groupRepo.GetByID(ctx, id) } +func (s *adminServiceImpl) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) { + platform = strings.TrimSpace(platform) + if id > 0 { + group, err := s.groupRepo.GetByIDLite(ctx, id) + if err != nil { + return nil, err + } + if platform == "" { + platform = group.Platform + } + } + if platform == "" { + platform = PlatformAnthropic + } + + candidates := defaultModelsListCandidateIDs(platform) + if id <= 0 || s.accountRepo == nil { + return candidates, nil + } + + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, id) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}, len(candidates)) + for _, model := range candidates { + seen[model] = struct{}{} + } + for _, acc := range accounts { + if acc.Platform != platform { + continue + } + for model := range acc.GetModelMapping() { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + candidates = append(candidates, model) + } + } + return candidates, nil +} + +func defaultModelsListCandidateIDs(platform string) []string { + switch platform { + case PlatformOpenAI: + return openai.DefaultModelIDs() + case PlatformGemini: + ids := make([]string, 0, len(geminicli.DefaultModels)) + for _, model := range geminicli.DefaultModels { + ids = append(ids, model.ID) + } + return ids + case PlatformAntigravity: + models := antigravity.DefaultModels() + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids + default: + ids := make([]string, 0, len(claude.DefaultModels)) + for _, model := range claude.DefaultModels { + ids = append(ids, model.ID) + } + return ids + } +} + func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { if input.RateMultiplier <= 0 { return nil, errors.New("rate_multiplier must be > 0") @@ -1694,6 +1778,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + ModelsListConfig: normalizeGroupModelsListConfig(input.ModelsListConfig), RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) @@ -1941,6 +2026,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MessagesDispatchModelConfig != nil { group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) } + if input.ModelsListConfig != nil { + group.ModelsListConfig = normalizeGroupModelsListConfig(*input.ModelsListConfig) + } if input.RPMLimit != nil { group.RPMLimit = *input.RPMLimit } @@ -2587,6 +2675,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U return updated, nil } +// UpdateAccountExtra 仅对 Extra JSONB 做 key 级合并,避免覆盖其它运行态键 +// (如 model_rate_limits / passive_usage_* 等)。 +func (s *adminServiceImpl) UpdateAccountExtra(ctx context.Context, id int64, updates map[string]any) error { + if len(updates) == 0 { + return nil + } + return s.accountRepo.UpdateExtra(ctx, id, updates) +} + // BulkUpdateAccounts updates multiple accounts in one request. // It merges credentials/extra keys instead of overwriting the whole object. func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 2f764d67..d01b11e6 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -459,6 +459,22 @@ func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID panic("unexpected InvalidateAPIKeyRateLimit call") } +func (s *billingCacheStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + panic("unexpected GetUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + panic("unexpected SetUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + panic("unexpected DeleteUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + panic("unexpected IncrUserPlatformQuotaUsageCache call") +} + func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall { t.Helper() calls := make([]subscriptionInvalidateCall, 0, expected) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 951f324c..9882b010 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1312,22 +1312,6 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt return body, nil } -// isModelNotFoundError 检测是否为模型不存在的 404 错误 -func isModelNotFoundError(statusCode int, body []byte) bool { - if statusCode != 404 { - return false - } - - bodyStr := strings.ToLower(string(body)) - keywords := []string{"model not found", "unknown model", "not found"} - for _, keyword := range keywords { - if strings.Contains(bodyStr, keyword) { - return true - } - } - return true // 404 without specific message also treated as model not found -} - // Forward 转发 Claude 协议请求(Claude → Gemini 转换) // // 限流处理流程: @@ -1362,6 +1346,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) if mappedModel == "" { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 @@ -2112,6 +2097,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co mappedModel := s.getMappedModel(account, originalModel) if mappedModel == "" { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) } billingModel := mappedModel @@ -4461,6 +4447,14 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp } // extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景) +// +// Anthropic streaming 的 usage 字段分布在两类事件中: +// - message_start:嵌套在 event.message.usage(input_tokens、cache_creation_input_tokens、 +// cache_read_input_tokens 等输入侧字段) +// - message_delta:位于顶层 event.usage(流结束时的最终 output_tokens) +// +// 仅读取顶层 event.usage 会漏掉 message_start 的输入侧字段,导致流式透传请求落库的 +// usage_logs 记录 input_tokens=0。 func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) { if !strings.HasPrefix(line, "data: ") { return @@ -4470,8 +4464,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs if json.Unmarshal([]byte(dataStr), &event) != nil { return } - u, ok := event["usage"].(map[string]any) - if !ok { + var u map[string]any + if eventType, _ := event["type"].(string); eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + u, _ = msg["usage"].(map[string]any) + } + } else { + u, _ = event["usage"].(map[string]any) + } + if u == nil { return } if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 1eb1451e..22124374 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -1301,6 +1301,19 @@ func TestExtractSSEUsage(t *testing.T) { line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`, expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3}, }, + { + // Anthropic message_start 把 usage 嵌套在 message.usage 下, + // 必须从这里提取输入侧字段(含 cache_read/cache_creation_input_tokens)。 + name: "message_start nested usage with input/cache tokens", + line: `data: {"type":"message_start","message":{"id":"msg_01","usage":{"input_tokens":35576,"cache_creation_input_tokens":0,"cache_read_input_tokens":12000,"output_tokens":1}}}`, + expected: ClaudeUsage{InputTokens: 35576, OutputTokens: 1, CacheReadInputTokens: 12000}, + }, + { + // message_start.message.usage.cache_creation 内的 5m/1h 明细也要解析。 + name: "message_start nested usage with cache_creation breakdown", + line: `data: {"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, + expected: ClaudeUsage{InputTokens: 100, CacheCreation5mTokens: 30, CacheCreation1hTokens: 70}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1311,6 +1324,29 @@ func TestExtractSSEUsage(t *testing.T) { } } +// TestExtractSSEUsage_StreamingSequence 复现 issue #2332:完整的 Anthropic streaming +// 序列(message_start → message_delta)必须把两类事件中的 usage 字段都汇入同一份累计值, +// 否则透传账号产出的 usage_logs 会出现 input_tokens=0、仅有 output_tokens 的"残缺"记录。 +func TestExtractSSEUsage_StreamingSequence(t *testing.T) { + svc := &AntigravityGatewayService{} + usage := &ClaudeUsage{} + + // 1) message_start:携带完整输入侧 usage(input_tokens + cache_read) + svc.extractSSEUsage( + `data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","content":[],"model":"claude-opus-4-6","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":35576,"cache_creation_input_tokens":0,"cache_read_input_tokens":12000,"output_tokens":1}}}`, + usage, + ) + // 2) message_delta:流结束时只带 output_tokens(无 input_tokens 字段) + svc.extractSSEUsage( + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":816}}`, + usage, + ) + + require.Equal(t, 35576, usage.InputTokens, "message_start 的 input_tokens 必须被记录,否则记账会缺失输入侧 token (#2332)") + require.Equal(t, 12000, usage.CacheReadInputTokens, "message_start 的 cache_read_input_tokens 必须被记录") + require.Equal(t, 816, usage.OutputTokens, "message_delta 的最终 output_tokens 必须被记录") +} + // TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测 func TestAntigravityClientWriter(t *testing.T) { t.Run("normal write succeeds", func(t *testing.T) { diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 35e130dc..c3e49458 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -94,7 +94,7 @@ func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int6 return nil } -func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time) error { +func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id int64, modelKey string, resetAt time.Time, reason ...string) error { s.modelRateLimitCalls = append(s.modelRateLimitCalls, modelRateLimitCall{accountID: id, modelKey: modelKey, resetAt: resetAt}) return nil } diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 3553a18a..74163179 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -87,6 +87,7 @@ type APIKeyAuthGroupSnapshot struct { AllowMessagesDispatch bool `json:"allow_messages_dispatch"` DefaultMappedModel string `json:"default_mapped_model,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + ModelsListConfig GroupModelsListConfig `json:"models_list_config,omitempty"` // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。 RPMLimit int `json:"rpm_limit"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index c752ce28..69c6086f 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 10 // v10: reload snapshots for group availability checks +const apiKeyAuthSnapshotVersion = 11 // v11: reload snapshots for custom models_list_config type apiKeyAuthCacheConfig struct { l1Size int @@ -272,6 +272,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, + ModelsListConfig: apiKey.Group.ModelsListConfig, RPMLimit: apiKey.Group.RPMLimit, } } @@ -342,6 +343,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, + ModelsListConfig: snapshot.Group.ModelsListConfig, RPMLimit: snapshot.Group.RPMLimit, } } diff --git a/backend/internal/service/api_key_auth_cache_version_test.go b/backend/internal/service/api_key_auth_cache_version_test.go new file mode 100644 index 00000000..5982e526 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_version_test.go @@ -0,0 +1,43 @@ +package service + +import "testing" + +func TestAPIKeyService_RejectsV10AuthSnapshotWithoutModelsListConfig(t *testing.T) { + groupID := int64(9) + svc := &APIKeyService{} + + apiKey, ok, err := svc.applyAuthCacheEntry("k-legacy-models-list", &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + Version: 10, + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "openai", + Platform: PlatformOpenAI, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + }) + + if err != nil { + t.Fatalf("expected stale snapshot to be ignored without error, got %v", err) + } + if ok { + t.Fatalf("expected v10 auth snapshot to be rejected after models_list_config was added") + } + if apiKey != nil { + t.Fatalf("expected no API key from stale snapshot, got %#v", apiKey) + } +} diff --git a/backend/internal/service/auth_email_oauth_auto.go b/backend/internal/service/auth_email_oauth_auto.go index 4db845c2..f86df14c 100644 --- a/backend/internal/service/auth_email_oauth_auto.go +++ b/backend/internal/service/auth_email_oauth_auto.go @@ -189,6 +189,8 @@ func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, } s.postAuthUserBootstrap(ctx, user, providerType, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { diff --git a/backend/internal/service/auth_email_oauth_auto_test.go b/backend/internal/service/auth_email_oauth_auto_test.go new file mode 100644 index 00000000..d128e6bb --- /dev/null +++ b/backend/internal/service/auth_email_oauth_auto_test.go @@ -0,0 +1,88 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newEmailOAuthAutoAuthService( + userRepo UserRepository, + settings map[string]string, + quotaRepo UserPlatformQuotaRepository, +) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + + return NewAuthService( + nil, // entClient — nil, updateUserSignupSource early return + userRepo, + nil, // redeemRepo — invitationCode="" 时不触发 + &refreshTokenCacheStub{}, + cfg, + settingService, + nil, // emailService + nil, // turnstileService + nil, // emailQueueService + nil, // promoService + nil, // defaultSubAssigner — nil, assignSubscriptions early return + nil, // affiliateService — nil, bindOAuthAffiliate early return + quotaRepo, + ) +} + +func TestEmailOAuthAuto_SnapshotsPlatformQuotaDefaults(t *testing.T) { + userRepo := &userRepoStub{nextID: 88} + quotaRepo := &userPlatformQuotaRepoStub{} + + svc := newEmailOAuthAutoAuthService( + userRepo, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"gemini": {"monthly": 100.0}}`, + }, + quotaRepo, + ) + + user, err := svc.createEmailOAuthUser( + context.Background(), + "newoauth@example.com", + "newoauth", + "github", + "", // invitationCode + "", // affiliateCode + ) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(88), user.ID) + + require.Len(t, quotaRepo.bulkInsertCalls, 1, "createEmailOAuthUser must snapshot platform quotas via BulkInsertInitial") + + records := quotaRepo.bulkInsertCalls[0] + var geminiRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "gemini" { + geminiRecord = &records[i] + break + } + } + require.NotNil(t, geminiRecord, "expected gemini platform record") + require.NotNil(t, geminiRecord.MonthlyLimitUSD) + require.InDelta(t, 100.0, *geminiRecord.MonthlyLimitUSD, 0.0001) +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index cf0be652..24d0eeee 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -283,6 +283,8 @@ func (s *AuthService) FinalizeOAuthEmailAccount( s.updateOAuthSignupSource(ctx, user.ID, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) return nil } diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index 3c02587b..5a78a348 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -112,6 +112,7 @@ func newOAuthEmailFlowAuthService( refreshTokenCache RefreshTokenCache, settings map[string]string, emailCache EmailCache, + quotaRepo UserPlatformQuotaRepository, // 新增 ) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ @@ -142,6 +143,7 @@ func newOAuthEmailFlowAuthService( nil, nil, nil, + quotaRepo, // 替换原来的 nil ) } @@ -175,6 +177,7 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -215,6 +218,7 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -274,6 +278,7 @@ func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -313,6 +318,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -360,6 +366,7 @@ func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) SettingKeyInvitationCodeEnabled: "true", }, &emailCacheStub{}, + nil, ) err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123") @@ -382,6 +389,7 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { SettingKeyRegistrationEnabled: "true", }, &emailCacheStub{}, + nil, ) err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "") @@ -389,3 +397,54 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "delete created oauth user") } + +func TestFinalizeOAuthEmailAccount_SnapshotsPlatformQuotaDefaults(t *testing.T) { + userRepo := &userRepoStub{nextID: 99} + quotaRepo := &userPlatformQuotaRepoStub{} + + authService := newOAuthEmailFlowAuthService( + userRepo, + nil, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 5.5}}`, + }, + &emailCacheStub{}, + quotaRepo, + ) + + user := &User{ + ID: 99, + Email: "newuser@example.com", + Role: RoleUser, + Status: StatusActive, + SignupSource: "oidc", + } + + err := authService.FinalizeOAuthEmailAccount( + context.Background(), + user, + "", + "oidc", + "", + ) + + require.NoError(t, err) + + require.Len(t, quotaRepo.bulkInsertCalls, 1, "snapshotPlatformQuotaDefaults must call BulkInsertInitial once on successful OAuth signup") + + records := quotaRepo.bulkInsertCalls[0] + var anthropicRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "anthropic" { + anthropicRecord = &records[i] + break + } + } + require.NotNil(t, anthropicRecord, "expected anthropic platform record") + require.Equal(t, int64(99), anthropicRecord.UserID) + require.NotNil(t, anthropicRecord.DailyLimitUSD) + require.InDelta(t, 5.5, *anthropicRecord.DailyLimitUSD, 0.0001) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 4e5b7b94..e4fa876c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -62,18 +62,19 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - entClient *dbent.Client - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService - affiliateService *AffiliateService - defaultSubAssigner DefaultSubscriptionAssigner + entClient *dbent.Client + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService + affiliateService *AffiliateService + defaultSubAssigner DefaultSubscriptionAssigner + userPlatformQuotaRepo UserPlatformQuotaRepository } type DefaultSubscriptionAssigner interface { @@ -81,9 +82,10 @@ type DefaultSubscriptionAssigner interface { } type signupGrantPlan struct { - Balance float64 - Concurrency int - Subscriptions []DefaultSubscriptionSetting + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + PlatformQuotas map[string]*DefaultPlatformQuotaSetting } // NewAuthService 创建认证服务实例 @@ -100,20 +102,22 @@ func NewAuthService( promoService *PromoService, defaultSubAssigner DefaultSubscriptionAssigner, affiliateService *AffiliateService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *AuthService { return &AuthService{ - entClient: entClient, - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, - affiliateService: affiliateService, - defaultSubAssigner: defaultSubAssigner, + entClient: entClient, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, + affiliateService: affiliateService, + defaultSubAssigner: defaultSubAssigner, + userPlatformQuotaRepo: userPlatformQuotaRepo, } } @@ -226,6 +230,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } s.postAuthUserBootstrap(ctx, user, "email", true) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) if s.affiliateService != nil { if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err) @@ -535,6 +541,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -685,6 +693,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) } } else { @@ -703,6 +713,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -764,18 +776,39 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + // ============ 全局 quota 装载(必须在 ResolveAuthSourceGrantSettings 之前) ============ + // 无论 auth source 是否 enabled,全局层都要先装载,确保 !enabled 早退路径也携带全局 quota。 + if quotas, err := s.settingService.GetDefaultPlatformQuotas(ctx); err == nil { + plan.PlatformQuotas = quotas + } else { + logger.LegacyPrintf("service.auth", "[Auth] Warning: load default platform quotas failed: %v (fail-open)", err) + } + // ============================================================================================ + resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) return plan } if !enabled { - return plan + return plan // plan.PlatformQuotas 已含全局层 } plan.Balance = resolved.Balance plan.Concurrency = resolved.Concurrency plan.Subscriptions = resolved.Subscriptions + + // ============ auth source quota merge(仅在 enabled 分支内) ============ + asQuotas := s.settingService.GetAuthSourcePlatformQuotas(ctx, signupSource) + if plan.PlatformQuotas != nil { + for platform, patch := range asQuotas { + if dst := plan.PlatformQuotas[platform]; dst != nil { + mergePlatformQuotaDefaults(dst, patch) + } + } + } + // ============================================================================== + return plan } @@ -1586,3 +1619,29 @@ func resolvedTokenVersion(user *User) int64 { fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff) return user.TokenVersion ^ fingerprint } + +// snapshotPlatformQuotaDefaults 把 plan.PlatformQuotas(4 platform × 3 window)以 +// BulkInsertInitial 形式写入 user_platform_quotas 表。失败 fail-open(仅 warn log)。 +func (s *AuthService) snapshotPlatformQuotaDefaults(ctx context.Context, userID int64, plan *signupGrantPlan) error { + if s.userPlatformQuotaRepo == nil || plan == nil || len(plan.PlatformQuotas) == 0 { + return nil + } + records := make([]UserPlatformQuotaRecord, 0, len(plan.PlatformQuotas)) + for platform, q := range plan.PlatformQuotas { + rec := UserPlatformQuotaRecord{ + UserID: userID, + Platform: platform, + } + if q != nil { + rec.DailyLimitUSD = q.DailyLimitUSD + rec.WeeklyLimitUSD = q.WeeklyLimitUSD + rec.MonthlyLimitUSD = q.MonthlyLimitUSD + } + records = append(records, rec) + } + if err := s.userPlatformQuotaRepo.BulkInsertInitial(ctx, records); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Warning: snapshot platform quota failed user=%d: %v (fail-open)", userID, err) + return nil // fail-open:返回 nil,让调用方继续 + } + return nil +} diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index 6845c1f4..87867395 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( emailSvc = service.NewEmailService(settingRepo, emailCache) } - svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil) + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil, nil) return svc, repo, client } @@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t }, } emailService := service.NewEmailService(nil, cache) - svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ ID: 41, @@ -820,8 +820,12 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) ( return ok, nil } -func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 53048b92..59ed8e52 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( values: settings, }, cfg) - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil) + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil, nil) return svc, repo, client } diff --git a/backend/internal/service/auth_service_platform_quota_test.go b/backend/internal/service/auth_service_platform_quota_test.go new file mode 100644 index 00000000..f58dc48c --- /dev/null +++ b/backend/internal/service/auth_service_platform_quota_test.go @@ -0,0 +1,157 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" +) + +// fakeInsertRecorder 记录 BulkInsertInitial 调用,实现 UserPlatformQuotaRepository port。 +type fakeInsertRecorder struct { + records []UserPlatformQuotaRecord + err error +} + +func (f *fakeInsertRecorder) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeInsertRecorder) BulkInsertInitial(_ context.Context, recs []UserPlatformQuotaRecord) error { + if f.err != nil { + return f.err + } + f.records = append(f.records, recs...) + return nil +} + +func (f *fakeInsertRecorder) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} + +func (f *fakeInsertRecorder) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeInsertRecorder) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeInsertRecorder) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error { + return nil +} + +func TestSnapshotPlatformQuotaDefaults_PassesToRepoBulkInsert(t *testing.T) { + fakeRepo := &fakeInsertRecorder{} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + "openai": {}, + "gemini": {}, + "antigravity": {}, + }, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 999, plan); err != nil { + t.Fatal(err) + } + if len(fakeRepo.records) != 4 { + t.Errorf("expected 4 records, got %d", len(fakeRepo.records)) + } + found := false + for _, r := range fakeRepo.records { + if r.UserID == 999 && r.Platform == "anthropic" && r.DailyLimitUSD != nil && *r.DailyLimitUSD == 5 { + found = true + } + } + if !found { + t.Error("anthropic daily = 5 not snapshotted") + } +} + +func TestSnapshotPlatformQuotaDefaults_NilPlanIsNoop(t *testing.T) { + fakeRepo := &fakeInsertRecorder{} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, nil); err != nil { + t.Errorf("nil plan should be noop, got %v", err) + } + if len(fakeRepo.records) != 0 { + t.Errorf("expected no records, got %d", len(fakeRepo.records)) + } +} + +func TestSnapshotPlatformQuotaDefaults_RepoErrorFailsOpen(t *testing.T) { + fakeRepo := &fakeInsertRecorder{err: fmt.Errorf("db down")} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil { + t.Errorf("fail-open: expected nil even on repo error, got %v", err) + } +} + +func TestSnapshotPlatformQuotaDefaults_NilRepoIsNoop(t *testing.T) { + s := &AuthService{userPlatformQuotaRepo: nil} + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{"a": {DailyLimitUSD: &five}}, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil { + t.Errorf("nil repo should be noop, got %v", err) + } +} + +// resolveSignupGrantPlan 测试:依赖完整的 AuthService 构造,需要 SettingService(含 settingRepoStub)。 +// settingRepoStub 已在 auth_service_register_test.go 中定义,同 package 可直接使用。 +func TestResolveSignupGrantPlan_GlobalQuotaLoadedBeforeAuthSource(t *testing.T) { + // 全局 quota JSON key(新格式) + settings := map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{ + "anthropic": {"daily": 10, "weekly": 50, "monthly": 200}, + "openai": {"daily": 5, "weekly": 25, "monthly": 100}, + "gemini": {"daily": 5, "weekly": 25, "monthly": 100}, + "antigravity": {"daily": 5, "weekly": 25, "monthly": 100} + }`, + } + svc := newAuthService(nil, settings, nil, nil) + plan := svc.resolveSignupGrantPlan(context.Background(), "email") + if plan.PlatformQuotas == nil { + t.Fatal("expected PlatformQuotas to be non-nil after loading global quota KVs") + } + q := plan.PlatformQuotas["anthropic"] + if q == nil { + t.Fatal("expected anthropic quota to be set") + } + if q.DailyLimitUSD == nil || *q.DailyLimitUSD != 10 { + t.Errorf("expected anthropic daily=10, got %v", q.DailyLimitUSD) + } +} + +// TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota 验证 P1 约束: +// !enabled 早退路径仍携带全局 quota(GetDefaultPlatformQuotas 在 ResolveAuthSourceGrantSettings 之前)。 +func TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota(t *testing.T) { + settings := map[string]string{ + SettingKeyRegistrationEnabled: "true", + // auth source 不配置(=> !enabled 路径) + SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 10, "weekly": 50, "monthly": 200}}`, + } + svc := newAuthService(nil, settings, nil, nil) + plan := svc.resolveSignupGrantPlan(context.Background(), "email") + // !enabled 路径:plan.PlatformQuotas 应已含全局层(不是 nil) + if plan.PlatformQuotas == nil { + t.Fatal("P1 violated: PlatformQuotas is nil even with global quota KVs set") + } + // P1 核心断言:disabled auth source 路径不能丢失全局 quota + if _, ok := plan.PlatformQuotas["anthropic"]; !ok { + t.Error("P1 violated: disabled auth source path dropped global platform quota") + } +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index ece02474..a7c0d260 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -73,6 +73,38 @@ type defaultSubscriptionAssignerStub struct { type refreshTokenCacheStub struct{} +type userPlatformQuotaRepoStub struct { + bulkInsertCalls [][]UserPlatformQuotaRecord + bulkInsertErr error +} + +func (s *userPlatformQuotaRepoStub) BulkInsertInitial(_ context.Context, records []UserPlatformQuotaRecord) error { + cloned := make([]UserPlatformQuotaRecord, len(records)) + copy(cloned, records) + s.bulkInsertCalls = append(s.bulkInsertCalls, cloned) + return s.bulkInsertErr +} + +func (s *userPlatformQuotaRepoStub) GetByUserPlatform(context.Context, int64, string) (*UserPlatformQuotaRecord, error) { + panic("unexpected GetByUserPlatform call") +} + +func (s *userPlatformQuotaRepoStub) ListByUser(context.Context, int64) ([]UserPlatformQuotaRecord, error) { + panic("unexpected ListByUser call") +} + +func (s *userPlatformQuotaRepoStub) IncrementUsageWithReset(context.Context, int64, string, float64, time.Time) error { + panic("unexpected IncrementUsageWithReset call") +} + +func (s *userPlatformQuotaRepoStub) UpsertForUser(context.Context, int64, []UserPlatformQuotaRecord) error { + panic("unexpected UpsertForUser call") +} + +func (s *userPlatformQuotaRepoStub) ResetExpiredWindow(context.Context, int64, string, string, time.Time) error { + panic("unexpected ResetExpiredWindow call") +} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -178,7 +210,7 @@ func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int6 return 0, nil } -func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { +func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache, quotaRepo UserPlatformQuotaRepository) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", @@ -213,6 +245,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, // promoService nil, // defaultSubAssigner nil, // affiliateService + quotaRepo, ) } @@ -220,7 +253,7 @@ func TestAuthService_Register_Disabled(t *testing.T) { repo := &userRepoStub{} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "false", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrRegDisabled) @@ -229,19 +262,62 @@ func TestAuthService_Register_Disabled(t *testing.T) { func TestAuthService_Register_DisabledByDefault(t *testing.T) { // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 repo := &userRepoStub{} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrRegDisabled) } +func TestAuthService_Register_SnapshotsPlatformQuotaDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 77} + quotaRepo := &userPlatformQuotaRepoStub{} + + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"openai": {"weekly": 12.34}}`, + }, nil, quotaRepo) + + _, user, err := service.Register(context.Background(), "newuser@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + + require.Len(t, quotaRepo.bulkInsertCalls, 1) + + records := quotaRepo.bulkInsertCalls[0] + var openaiRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "openai" { + openaiRecord = &records[i] + break + } + } + require.NotNil(t, openaiRecord, "expected openai platform record") + require.Equal(t, int64(77), openaiRecord.UserID) + require.NotNil(t, openaiRecord.WeeklyLimitUSD) + require.InDelta(t, 12.34, *openaiRecord.WeeklyLimitUSD, 0.0001) +} + +func TestAuthService_Register_DoesNotSnapshotOnDisabled(t *testing.T) { + repo := &userRepoStub{} + quotaRepo := &userPlatformQuotaRepoStub{} + + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "false", + }, nil, quotaRepo) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) + + require.Empty(t, quotaRepo.bulkInsertCalls, "registration rejected before user creation must not snapshot") +} + func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, nil) + }, nil, nil) // 应返回服务不可用错误,而不是允许绕过验证 _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "") @@ -254,7 +330,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, cache) + }, cache, nil) _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) @@ -268,7 +344,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, cache) + }, cache, nil) _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "") require.ErrorIs(t, err, ErrInvalidVerifyCode) @@ -279,7 +355,7 @@ func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -289,7 +365,7 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -299,7 +375,7 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) { repo := &userRepoStub{} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") require.ErrorIs(t, err, ErrEmailReserved) @@ -310,7 +386,7 @@ func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@other.com", "password") require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) @@ -327,7 +403,7 @@ func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`, - }, nil) + }, nil, nil) _, user, err := service.Register(context.Background(), "user@example.com", "password") require.NoError(t, err) @@ -340,7 +416,7 @@ func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, - }, nil) + }, nil, nil) err := service.SendVerifyCode(context.Background(), "user@other.com") require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) @@ -354,7 +430,7 @@ func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -365,7 +441,7 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { repo := &userRepoStub{createErr: ErrEmailExists} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -376,7 +452,7 @@ func TestAuthService_Register_Success(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) @@ -394,7 +470,7 @@ func TestAuthService_Register_Success(t *testing.T) { func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { repo := &userRepoStub{} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) // 创建用户并生成 token user := &User{ @@ -436,7 +512,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { TokenVersion: 1, } repo := &userRepoStub{user: user} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) // 创建过期 token service.cfg.JWT.ExpireHour = -1 @@ -453,7 +529,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { } func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 0 @@ -461,7 +537,7 @@ func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) } func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 90 @@ -469,7 +545,7 @@ func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { } func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 0 @@ -494,7 +570,7 @@ func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { } func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 90 @@ -525,7 +601,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { SettingKeyRegistrationEnabled: "true", SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") @@ -549,7 +625,7 @@ func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *tes SettingKeyAuthSourceDefaultEmailConcurrency: "7", SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") @@ -572,7 +648,7 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes SettingKeyAuthSourceDefaultEmailConcurrency: "88", SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-global@test.com", "password") @@ -595,7 +671,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul SettingKeyAuthSourceDefaultEmailConcurrency: "5", SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-merged@test.com", "password") @@ -618,7 +694,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} @@ -654,7 +730,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} @@ -677,7 +753,7 @@ func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.Ding DingTalk: dtCfg, } settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) - return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil) + return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil, nil) } // minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。 diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 3512822f..ce99c44d 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -55,6 +55,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier nil, // promoService nil, // defaultSubAssigner nil, // affiliateService + nil, // userPlatformQuotaRepo ) } diff --git a/backend/internal/service/bedrock_request.go b/backend/internal/service/bedrock_request.go index 8a1fb317..f4416ce3 100644 --- a/backend/internal/service/bedrock_request.go +++ b/backend/internal/service/bedrock_request.go @@ -185,6 +185,7 @@ func BuildBedrockURL(region, modelID string, stream bool) string { // 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl) // 6. 修复 thinking 字段兼容性(Opus 4.7 仅支持 adaptive,enabled 需要 budget_tokens) // 7. 清理 tool_use.id / tool_use_id 中 Bedrock 不接受的字符 +// 8. 根据最终 Bedrock beta tokens 剥离不再支持的 beta 字段 func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) { betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens, false) @@ -195,6 +196,9 @@ func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ( func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string, ccCompat bool) ([]byte, error) { var err error + betaTokens = filterBedrockBetaTokens(betaTokens) + body = sanitizeBedrockFieldsForBetaTokens(body, betaTokens) + // 注入 anthropic_version(Bedrock 要求) body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31") if err != nil { @@ -471,6 +475,8 @@ var bedrockSupportedBetaTokens = map[string]bool{ "tool-examples-2025-10-29": true, } +const bedrockContextManagementBetaToken = "context-management-2025-06-27" + // bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则 // Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头 var bedrockBetaTokenTransforms = map[string]string{ @@ -617,6 +623,22 @@ func filterBedrockBetaTokens(tokens []string) []string { return result } +func sanitizeBedrockFieldsForBetaTokens(body []byte, betaTokens []string) []byte { + if !containsBedrockBetaToken(betaTokens, bedrockContextManagementBetaToken) && gjson.GetBytes(body, "context_management").Exists() { + body, _ = sjson.DeleteBytes(body, "context_management") + } + return body +} + +func containsBedrockBetaToken(tokens []string, target string) bool { + for _, token := range tokens { + if token == target { + return true + } + } + return false +} + // bedrockToolUseIDRe 匹配 Bedrock 允许的 tool_use ID 字符(字母、数字、下划线、连字符) var bedrockToolUseIDRe = regexp.MustCompile(`[^a-zA-Z0-9_-]`) diff --git a/backend/internal/service/bedrock_request_test.go b/backend/internal/service/bedrock_request_test.go index 98942ba4..94f1a118 100644 --- a/backend/internal/service/bedrock_request_test.go +++ b/backend/internal/service/bedrock_request_test.go @@ -378,6 +378,67 @@ func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) { }) } +func TestPrepareBedrockRequestBodyWithTokens_ContextManagementRequiresSupportedBeta(t *testing.T) { + modelID := "us.anthropic.claude-opus-4-6-v1" + + t.Run("strips context_management when final tokens omit context-management beta", func(t *testing.T) { + input := `{ + "messages":[{"role":"user","content":"hi"}], + "max_tokens":100, + "context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]} + }` + betaTokens := []string{"context-1m-2025-08-07"} + originalTokens := append([]string(nil), betaTokens...) + + result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), modelID, betaTokens, false) + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + assert.Equal(t, originalTokens, betaTokens) + assert.Equal(t, originalTokens, bedrockAnthropicBetaNames(result)) + }) + + t.Run("leaves body without context_management otherwise intact", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}` + + result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), modelID, nil, false) + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + assert.Equal(t, "hi", gjson.GetBytes(result, "messages.0.content").String()) + assert.Equal(t, int64(100), gjson.GetBytes(result, "max_tokens").Int()) + }) + + t.Run("filters explicit unsupported context-management beta and strips field", func(t *testing.T) { + input := `{ + "messages":[{"role":"user","content":"hi"}], + "max_tokens":100, + "context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]} + }` + + result, err := PrepareBedrockRequestBodyWithTokens( + []byte(input), + modelID, + []string{bedrockContextManagementBetaToken, "context-1m-2025-08-07"}, + false, + ) + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + assert.Equal(t, []string{"context-1m-2025-08-07"}, bedrockAnthropicBetaNames(result)) + }) +} + +func bedrockAnthropicBetaNames(body []byte) []string { + arr := gjson.GetBytes(body, "anthropic_beta").Array() + names := make([]string, len(arr)) + for i, token := range arr { + names[i] = token.String() + } + return names +} + func TestBedrockCrossRegionPrefix(t *testing.T) { tests := []struct { region string diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 050db55b..2b7c06ba 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -11,18 +11,31 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "golang.org/x/sync/singleflight" ) // 错误定义 // 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 +// errBillingCacheUnavailable 内部哨兵:用于 quota 校验路径在 cache==nil 时 +// 与"Redis 故障"走同一条 fail-open + DB 一次性检查的分支。 +var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable") + var ( ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。 ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded") ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded") + + // user × platform quota(HTTP 429 Too Many Requests + Retry-After header)。 + // 选用 429 而非 403:限额耗尽属于"暂时性资源用尽,重试可恢复"的场景(RFC 6585), + // 大量 SDK(如 OpenAI 兼容客户端)只对 429 触发自动退避并读取 Retry-After, + // 用 403 会被视为"权限不足,重试无意义"导致客户端直接报错且不退避。 + ErrUserPlatformDailyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.") + ErrUserPlatformWeeklyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.") + ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -94,6 +107,7 @@ type BillingCacheService struct { userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker + userPlatformQuotaRepo UserPlatformQuotaRepository cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup @@ -101,6 +115,7 @@ type BillingCacheService struct { cacheWriteMu sync.RWMutex stopped atomic.Bool balanceLoadSF singleflight.Group + quotaLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -117,6 +132,7 @@ func NewBillingCacheService( userRPMCache UserRPMCache, userGroupRateRepo UserGroupRateRepository, cfg *config.Config, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *BillingCacheService { svc := &BillingCacheService{ cache: cache, @@ -126,6 +142,7 @@ func NewBillingCacheService( userRPMCache: userRPMCache, userGroupRateRepo: userGroupRateRepo, cfg: cfg, + userPlatformQuotaRepo: userPlatformQuotaRepo, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() @@ -655,6 +672,30 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co }) } +// IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。 +// +// 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage, +// 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。 +// 写延迟通常 < 1ms(本地 Redis),换取 quota 视图实时性的取舍合理。 +// +// Redis 写失败用 ALERT 级 log;DB 持久化由 caller 单独 goroutine 兜底(gateway_service.go)。 +func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) { + if s.cache == nil { + return + } + if platform == "" || cost <= 0 { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil { + logger.LegacyPrintf("service.billing_cache", + "ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v", + userID, platform, cost, err) + } +} + // ============================================ // 统一检查方法 // ============================================ @@ -662,7 +703,8 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co // CheckBillingEligibility 检查用户是否有资格发起请求 // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) -func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error { +// platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。 +func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error { // 简易模式:跳过所有计费检查 if s.cfg.RunMode == config.RunModeSimple { return nil @@ -684,6 +726,13 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } } + // user × platform quota 仅在 standard(余额)模式生效;订阅模式豁免 + if !isSubscriptionMode { + if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil { + return err + } + } + // Check API Key rate limits (applies to both billing modes) if apiKey != nil && apiKey.HasRateLimits() { if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { @@ -975,3 +1024,257 @@ func circuitStateString(state billingCircuitBreakerState) string { return "unknown" } } + +// checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。 +// 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata)。 +// checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。 +// +// 流程(Redis-first / DB-fallback): +// 1. 先读 Redis cache;若命中且 SchemaVersion==1,直接用 entry 中的 limits 和 window_start 做校验, +// 免除 DB 查询。 +// 2. cache MISS 或旧版 entry(SchemaVersion==0)→ 查 DB 回填完整 entry(含 limits/window_start)。 +// 3. Redis 故障(err != nil)→ fail-open,查 DB 做一次性检查,不回填。 +func (s *BillingCacheService) checkUserPlatformQuotaEligibility( + ctx context.Context, + userID int64, + platform string, +) error { + if platform == "" || s.userPlatformQuotaRepo == nil { + return nil + } + + // cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。 + // 其他 check* 方法(balance/subscription/rate-limit)也有类似守卫。 + var ( + entry *UserPlatformQuotaCacheEntry + ok bool + cacheErr error + ) + if s.cache != nil { + entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform) + } else { + // 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查 + cacheErr = errBillingCacheUnavailable + } + + // --- cache HIT with current schema → 直接用 entry,不查 DB --- + if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 { + now := time.Now() + dailyUsage := entry.DailyUsageUSD + weeklyUsage := entry.WeeklyUsageUSD + monthlyUsage := entry.MonthlyUsageUSD + // 若窗口已更新(DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较, + // 同时记录新窗口起点用于后续刷新 cache entry。 + // 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力, + // 持久化数据始终正确。 + windowExpired := false + newDailyStart := entry.DailyWindowStart + newWeeklyStart := entry.WeeklyWindowStart + newMonthlyStart := entry.MonthlyWindowStart + if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) { + dailyUsage = 0 + windowExpired = true + dayStart := timezone.StartOfDay(now) + newDailyStart = &dayStart + } + if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) { + weeklyUsage = 0 + windowExpired = true + weekStart := timezone.StartOfWeek(now) + newWeeklyStart = &weekStart + } + if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) { + monthlyUsage = 0 + windowExpired = true + monthStart := now + newMonthlyStart = &monthStart + } + // 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis(而非 Delete)。 + // 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到 + // EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。 + // 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。 + // 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。 + // 用 context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。 + // 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。 + if windowExpired && s.cache != nil { + refreshed := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: dailyUsage, + WeeklyUsageUSD: weeklyUsage, + MonthlyUsageUSD: monthlyUsage, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: entry.DailyLimitUSD, + WeeklyLimitUSD: entry.WeeklyLimitUSD, + MonthlyLimitUSD: entry.MonthlyLimitUSD, + DailyWindowStart: newDailyStart, + WeeklyWindowStart: newWeeklyStart, + MonthlyWindowStart: newMonthlyStart, + } + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil { + logger.LegacyPrintf("service.billing_cache", + "Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v", + userID, platform, setErr) + } + setCancel() + } + if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now)) + } + return nil + } + + // --- cache MISS、旧版 entry 或 Redis 故障 → 查 DB(singleflight 合并并发回源)--- + // 使用 DoChan 而非 Do:avoid sharing the first caller's ctx among all dedupe followers. + // 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。 + sfKey := strconv.FormatInt(userID, 10) + ":" + platform + ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) { + // 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx, + // 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer bgCancel() + return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform) + }) + var ( + v any + dbErr error + ) + select { + case res := <-ch: + v, dbErr = res.Val, res.Err + case <-ctx.Done(): + // 当前 caller 的 ctx 被取消:fail-open,不阻断 (此请求已无意义)。 + logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err()) + return nil + } + if dbErr != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr) + return nil + } + rec, _ := v.(*UserPlatformQuotaRecord) + if rec == nil { + return nil + } + + now := time.Now() + dailyUsage := rec.DailyUsageUSD + weeklyUsage := rec.WeeklyUsageUSD + monthlyUsage := rec.MonthlyUsageUSD + if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) { + dailyUsage = 0 + } + if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) { + weeklyUsage = 0 + } + if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) { + monthlyUsage = 0 + } + + // Redis 故障时 fail-open:不回填,直接用 DB 数据做一次性检查 + if cacheErr != nil { + if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) + } + return nil + } + + // cache MISS 或旧版 entry → 回填完整 entry(含 limits 和 window_start) + newEntry := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: dailyUsage, + WeeklyUsageUSD: weeklyUsage, + MonthlyUsageUSD: monthlyUsage, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: rec.DailyLimitUSD, + WeeklyLimitUSD: rec.WeeklyLimitUSD, + MonthlyLimitUSD: rec.MonthlyLimitUSD, + DailyWindowStart: rec.DailyWindowStart, + WeeklyWindowStart: rec.WeeklyWindowStart, + MonthlyWindowStart: rec.MonthlyWindowStart, + } + if s.cache != nil { + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + // 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms, + // 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败, + // 让下一次 preflight 仍然 MISS 并击穿到 DB(高并发下增大 DB 压力)。 + setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr) + } + setCancel() + } + + if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) + } + return nil +} + +// withWindowResetsMetadata 给 quota error 附加 window_resets_at metadata(RFC3339)。 +func withWindowResetsMetadata(err error, resetAt time.Time) error { + appErr, ok := err.(*infraerrors.ApplicationError) + if !ok || appErr == nil { + return err + } + return appErr.WithMetadata(map[string]string{ + "window_resets_at": resetAt.Format(time.RFC3339), + }) +} + +// nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。 +// 必须与 timezone.StartOfDay 同口径,否则 Retry-After 会偏差。 +func nextDailyReset(now time.Time) time.Time { + return timezone.StartOfDay(now).AddDate(0, 0, 1) +} + +// nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。 +// 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。 +func nextWeeklyReset(now time.Time) time.Time { + return timezone.StartOfWeek(now).AddDate(0, 0, 7) +} + +// nextMonthlyResetFrom 返回 30 天滚动窗口的下次重置时间(start + 30d)。 +// start 为 nil(未初始化)或已过期(now-start >= 30d,与 monthlyQuotaWindowExpired 同口径)时 +// 退化为 now+30d:过期窗口会在下次 increment 时重置为 now,下次重置即 now+30d; +// 否则按 start 计算会得到一个过去的时间,使 Retry-After 落回 fallback 并触发客户端紧凑重试。 +func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time { + if start == nil || now.Sub(*start) >= 30*24*time.Hour { + return now.Add(30 * 24 * time.Hour) + } + return start.Add(30 * 24 * time.Hour) +} + +// quotaWindowExpired 判断窗口是否已过期:start 为 nil(未初始化)或在 currWindowStart 之前视为已过期。 +func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool { + if start == nil { + return true + } + return start.Before(currWindowStart) +} + +// monthlyQuotaWindowExpired 判断 30 天滚动月度窗口是否已过期。 +// 过期条件:now - start >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。 +// start 为 nil 时视为已过期(未初始化窗口)。 +func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool { + if start == nil { + return true + } + return now.Sub(*start) >= 30*24*time.Hour +} diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go index de66136f..cb71b886 100644 --- a/backend/internal/service/billing_cache_service_rpm_test.go +++ b/backend/internal/service/billing_cache_service_rpm_test.go @@ -74,7 +74,7 @@ func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGrou t.Helper() // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。 // 我们只直接测 checkRPM。 - svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}) + svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}, nil) t.Cleanup(svc.Stop) return svc } diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 962becf0..b443d97e 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -67,6 +67,22 @@ func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, ke return nil } +func (s *billingCacheMissStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (s *billingCacheMissStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *billingCacheMissStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return nil +} + +func (s *billingCacheMissStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + return nil +} + type balanceLoadUserRepoStub struct { mockUserRepo calls atomic.Int64 @@ -100,7 +116,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}, nil) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 849e24b8..bcd086fa 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -68,9 +68,25 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, return nil } +func (b *billingCacheWorkerStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (b *billingCacheWorkerStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + return nil +} + +func (b *billingCacheWorkerStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return nil +} + +func (b *billingCacheWorkerStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + return nil +} + func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil) t.Cleanup(svc.Stop) start := time.Now() @@ -92,7 +108,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_cache_service_user_platform_quota_test.go b/backend/internal/service/billing_cache_service_user_platform_quota_test.go new file mode 100644 index 00000000..57697ddb --- /dev/null +++ b/backend/internal/service/billing_cache_service_user_platform_quota_test.go @@ -0,0 +1,595 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" +) + +// fakeIncrCache 仅记录 IncrUserPlatformQuotaUsageCache 被调用的参数。 +type fakeIncrCache struct { + BillingCache + calls []incrCall +} + +type incrCall struct { + userID int64 + platform string + cost float64 + ttl time.Duration +} + +func (f *fakeIncrCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + f.calls = append(f.calls, incrCall{userID, platform, cost, ttl}) + return nil +} + +// IncrementUserPlatformQuotaUsage 已改为同步直写,不再走 worker。 +// 测试验证:同步调用立即调到 cache.IncrUserPlatformQuotaUsageCache。 +func TestIncrementUserPlatformQuotaUsage_SyncCallsCache(t *testing.T) { + fake := &fakeIncrCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 120 + + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + } + + s.IncrementUserPlatformQuotaUsage(101, "anthropic", 0.25) + s.IncrementUserPlatformQuotaUsage(101, "openai", 0.50) + + if len(fake.calls) != 2 { + t.Fatalf("expected 2 incr calls, got %d", len(fake.calls)) + } + if fake.calls[0] != (incrCall{101, "anthropic", 0.25, 120 * time.Second}) { + t.Errorf("call[0] = %+v", fake.calls[0]) + } + if fake.calls[1] != (incrCall{101, "openai", 0.50, 120 * time.Second}) { + t.Errorf("call[1] = %+v", fake.calls[1]) + } +} + +// ── T6 tests: checkUserPlatformQuotaEligibility ────────────────────────────── + +// fakeQuotaRepo 实现 UserPlatformQuotaRepository 最小子集 +type fakeQuotaRepo struct { + rec *UserPlatformQuotaRecord +} + +func (f *fakeQuotaRepo) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return f.rec, nil +} + +func (f *fakeQuotaRepo) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeQuotaRepo) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} + +func (f *fakeQuotaRepo) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeQuotaRepo) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeQuotaRepo) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error { + return nil +} + +// fakeFullCache 同时支持 Get + Set + Incr + Delete。 +// mu 保护 entry 和 deleteCalls,防止异步 goroutine 与主 goroutine 之间的 data race。 +type fakeFullCache struct { + BillingCache + mu sync.Mutex + entry *UserPlatformQuotaCacheEntry + deleteCalls int +} + +// getDeleteCalls 线程安全地读取 deleteCalls。 +func (f *fakeFullCache) getDeleteCalls() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.deleteCalls +} + +// getEntry 线程安全地读取 entry。 +func (f *fakeFullCache) getEntry() *UserPlatformQuotaCacheEntry { + f.mu.Lock() + defer f.mu.Unlock() + return f.entry +} + +func (f *fakeFullCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.entry == nil { + return nil, false, nil + } + return f.entry, true, nil +} + +func (f *fakeFullCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, e *UserPlatformQuotaCacheEntry, _ time.Duration) error { + f.mu.Lock() + defer f.mu.Unlock() + f.entry = e + return nil +} + +func (f *fakeFullCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error { + f.mu.Lock() + defer f.mu.Unlock() + f.deleteCalls++ + f.entry = nil + return nil +} + +func newServiceForPreflight(t *testing.T, repo UserPlatformQuotaRepository, cache BillingCache) *BillingCacheService { + t.Helper() + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + return &BillingCacheService{ + cache: cache, + cfg: cfg, + userPlatformQuotaRepo: repo, + } +} + +// currentDayStart 返回全局时区当天 0 点(与生产 timezone.StartOfDay 同口径,确保窗口有效)。 +func currentDayStart() *time.Time { + s := timezone.StartOfDay(time.Now()) + return &s +} + +func TestCheckUserPlatformQuotaEligibility_AllowsWhenUnderLimit(t *testing.T) { + daily := 10.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 4.5, + DailyLimitUSD: &daily, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_DailyExhausted(t *testing.T) { + daily := 5.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 5.0, + DailyLimitUSD: &daily, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("expected ErrUserPlatformDailyQuotaExhausted, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_NilLimitMeansUnlimited(t *testing.T) { + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 999, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + // DailyLimitUSD nil → 无限额 + }} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("nil limits should be unlimited, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_ZeroLimitImmediateBlock(t *testing.T) { + zero := 0.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &zero, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 0, + DailyLimitUSD: &zero, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("expected daily exhausted for limit=0, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_NoRecordMeansUnlimited(t *testing.T) { + repo := &fakeQuotaRepo{rec: nil} + cache := &fakeFullCache{} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("no record = unlimited, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB 验证旧版 entry(SchemaVersion=0) +// 触发 DB 回退路径,并在 DB 数据判断配额是否超限。 +// DB record 需设置有效的 window_start,否则 quotaWindowExpired 会将 usage 归零(nil 窗口视为已过期)。 +func TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB(t *testing.T) { + daily := 5.0 + dayStart := currentDayStart() + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, DailyUsageUSD: 6.0, + DailyWindowStart: dayStart, + }} + // SchemaVersion=0(旧 entry),应走 DB 路径 + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{DailyUsageUSD: 1.0}} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("旧版 entry 应走 DB 路径并报 daily exhausted, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache 验证 cache HIT 时若窗口已过期,usage 归零,用户放行。 +func TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache(t *testing.T) { + daily := 5.0 + past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) // 远古窗口起始,肯定已过期 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 10.0, // 超限,但窗口已过期 + DailyLimitUSD: &daily, + DailyWindowStart: &past, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if err != nil { + t.Errorf("过期窗口应归零放行, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache 验证: +// V1 HIT 路径检测到窗口过期时,用 reset 后的 entry 同步覆盖 Redis(而非 Delete): +// 1. 当前请求以本地清零值判断 → 放行 +// 2. cache entry 被替换为新 entry: usage 清零 + window_start 更新到当前窗口 +// limit 保留;这样并发 IncrUserPlatformQuotaUsage 的 Lua INCR 可正确累加到新窗口。 +func TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache(t *testing.T) { + daily := 5.0 + // 远古窗口起始,确保 quotaWindowExpired 返回 true + past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 10.0, // 超限,但窗口已过期 → 应被本地清零后放行 + DailyLimitUSD: &daily, + DailyWindowStart: &past, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + + // 本次 check 应放行(本地清零后 usage=0 < limit=5) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if err != nil { + t.Errorf("过期窗口应归零放行, got %v", err) + } + + // 验证 cache entry 已被刷新:usage 清零、limit 保留、window_start 更新到当前窗口 + refreshed := cache.getEntry() + if refreshed == nil { + t.Fatal("窗口过期后 cache entry 不应为 nil(应被 SetCache 覆盖,而非 Delete)") + } + if refreshed.DailyUsageUSD != 0 { + t.Errorf("刷新后 DailyUsageUSD = %v, want 0", refreshed.DailyUsageUSD) + } + if refreshed.DailyLimitUSD == nil || *refreshed.DailyLimitUSD != daily { + t.Errorf("刷新后 DailyLimitUSD = %v, want %v(保留)", refreshed.DailyLimitUSD, daily) + } + if refreshed.SchemaVersion != UserPlatformQuotaCacheSchemaV1 { + t.Errorf("刷新后 SchemaVersion = %d, want V1", refreshed.SchemaVersion) + } + if refreshed.DailyWindowStart == nil || refreshed.DailyWindowStart.Equal(past) { + t.Errorf("刷新后 DailyWindowStart = %v, 应更新到当前窗口而非保留 past=%v", refreshed.DailyWindowStart, past) + } +} + +// ── T5 tests: QueueUpdateUserPlatformQuotaUsage ─────────────────────────────── + +// ── C-NEW-1: monthlyQuotaWindowExpired 30 天滚动测试 ───────────────────────── + +func TestMonthlyQuotaWindowExpired_NilStart(t *testing.T) { + if !monthlyQuotaWindowExpired(nil, time.Now().UTC()) { + t.Error("nil start should be considered expired") + } +} + +func TestMonthlyQuotaWindowExpired_Expired(t *testing.T) { + start := time.Now().UTC().Add(-30 * 24 * time.Hour) + if !monthlyQuotaWindowExpired(&start, time.Now().UTC()) { + t.Error("start exactly 30 days ago should be expired") + } +} + +func TestMonthlyQuotaWindowExpired_Active(t *testing.T) { + start := time.Now().UTC().Add(-29 * 24 * time.Hour) + if monthlyQuotaWindowExpired(&start, time.Now().UTC()) { + t.Error("start 29 days ago should NOT be expired") + } +} + +// TestMonthlyQuotaWindowExpired_CrossMonthBoundary 验证跨自然月时 30 天未满不视为过期。 +func TestMonthlyQuotaWindowExpired_CrossMonthBoundary(t *testing.T) { + // 窗口起始 4 月 20 日;5 月 1 日只过了 11 天,不足 30 天 + start := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + if monthlyQuotaWindowExpired(&start, now) { + t.Error("11 days into window should NOT be expired (30-day rolling, not calendar month)") + } +} + +// TestNextMonthlyResetFrom 验证 30 天滚动重置时间计算。 +func TestNextMonthlyResetFrom_WithStart(t *testing.T) { + start := time.Date(2026, 5, 1, 10, 0, 0, 0, time.UTC) + want := start.Add(30 * 24 * time.Hour) + now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(&start, now) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom = %v, want %v", got, want) + } +} + +func TestNextMonthlyResetFrom_NilStart(t *testing.T) { + now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(nil, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(nil) = %v, want now+30d=%v", got, want) + } +} + +func TestNextMonthlyResetFrom_NilStart_NotEqualToNow(t *testing.T) { + now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(nil, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(nil) = %v, want %v (now+30d)", got, want) + } + if got.Equal(now) { + t.Error("nextMonthlyResetFrom(nil) must not return now (should be now+30d)") + } +} + +// TestNextMonthlyResetFrom_ExpiredStart 验证窗口已过期(now-start >= 30d)时, +// 下次重置时间为 now+30d,而非 start+30d(后者已是过去时间,会让 Retry-After 落回 +// fallback 并触发客户端紧凑重试)。 +func TestNextMonthlyResetFrom_ExpiredStart(t *testing.T) { + start := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) // 距 start 61 天,已过期 + got := nextMonthlyResetFrom(&start, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(expired) = %v, want now+30d=%v", got, want) + } + if !got.After(now) { + t.Error("expired window 的下次重置必须在 now 之后,不能是过去时间") + } +} + +func TestIncrementUserPlatformQuotaUsage_GuardsAgainstEmpty(t *testing.T) { + fake := &fakeIncrCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + } + + s.IncrementUserPlatformQuotaUsage(1, "", 0.5) // empty platform → noop + s.IncrementUserPlatformQuotaUsage(1, "openai", 0) // zero cost → noop + s.IncrementUserPlatformQuotaUsage(1, "openai", -0.1) // negative → noop + + if len(fake.calls) != 0 { + t.Errorf("expected 0 calls (all guarded), got %d", len(fake.calls)) + } +} + +// ── C-NEW-2: 订阅模式豁免 user×platform quota 检查 ────────────────────────── +// 通过直接调用 checkUserPlatformQuotaEligibility 验证: +// 1. standard 模式下 limit=0 → 拦截 +// 2. 订阅模式豁免通过 isSubscriptionMode 守卫体现 — 逻辑已在 CheckBillingEligibility 里加 !isSubscriptionMode 条件 +// 此处用单元测试直接验证底层 checkUserPlatformQuotaEligibility 的行为(quota 超限确实拦截), +// 而 subscription bypass 逻辑则在 CheckBillingEligibility 中通过条件判断保证,不绕过 sub eligibility 内部复杂依赖。 + +// fakeZeroQuotaCache 模拟 cache 命中且 daily limit=0(quota 耗尽)。 +type fakeZeroQuotaCache struct { + BillingCache + called bool +} + +func (f *fakeZeroQuotaCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) { + f.called = true + daily := 0.0 + entry := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 0, + DailyLimitUSD: &daily, + DailyWindowStart: func() *time.Time { t := time.Now().UTC(); return &t }(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + } + return entry, true, nil +} + +func (f *fakeZeroQuotaCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error { + return nil +} + +// SetUserPlatformQuotaCache 在 weekly/monthly window_start 为 nil 时,checkUserPlatform... +// 会触发"窗口过期 → SetCache 刷新"分支。fake 用 noop 避免 panic。 +func (f *fakeZeroQuotaCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, _ *UserPlatformQuotaCacheEntry, _ time.Duration) error { + return nil +} + +// GetSubscriptionCache 返回有效订阅(active、未过期、usage 远低于 limit), +// 用于支持 checkSubscriptionEligibility 通过,以便验证 quota 检查不被触发。 +func (f *fakeZeroQuotaCache) GetSubscriptionCache(_ context.Context, _ int64, _ int64) (*SubscriptionCacheData, error) { + return &SubscriptionCacheData{ + Status: SubscriptionStatusActive, + ExpiresAt: time.Now().Add(30 * 24 * time.Hour), + DailyUsage: 0, + WeeklyUsage: 0, + MonthlyUsage: 0, + }, nil +} + +func (f *fakeZeroQuotaCache) GetUserBalanceCache(_ context.Context, _ int64) (float64, bool, error) { + return 100.0, true, nil +} + +// TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero 验证: +// standard 模式下 limit=0 的 platform quota 确实会被拦截(守卫底层逻辑正确)。 +func TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero(t *testing.T) { + fake := &fakeZeroQuotaCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("standard mode with limit=0 should return ErrUserPlatformDailyQuotaExhausted, got: %v", err) + } + if !fake.called { + t.Error("GetUserPlatformQuotaCache should have been called in standard mode") + } +} + +// TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota 验证(C-NEW-2): +// 订阅模式用户不受 user×platform quota 拦截,GetUserPlatformQuotaCache 不应被调用。 +func TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota(t *testing.T) { + fake := &fakeZeroQuotaCache{} // GetUserPlatformQuotaCache 返回 limit=0,若被调用则拦截 + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + + subGroup := &Group{ + ID: 10, + SubscriptionType: "subscription", + Status: "active", + // 无 DailyLimitUSD → checkSubscriptionEligibility 不会因超限失败 + } + sub := &UserSubscription{Status: "active"} + user := &User{ID: 42} + + err := s.CheckBillingEligibility(context.Background(), user, nil, subGroup, sub, "anthropic") + // 订阅模式下不应收到任何 user×platform quota 错误 + if errors.Is(err, ErrUserPlatformDailyQuotaExhausted) || + errors.Is(err, ErrUserPlatformWeeklyQuotaExhausted) || + errors.Is(err, ErrUserPlatformMonthlyQuotaExhausted) { + t.Errorf("subscription mode should bypass user×platform quota, got: %v", err) + } + // GetUserPlatformQuotaCache 不应被调用 + if fake.called { + t.Error("GetUserPlatformQuotaCache must NOT be called in subscription mode (C-NEW-2)") + } +} + +// TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota 验证: +// 非订阅模式(group=nil)用户 platform quota 超限时被拦截,quota cache 被查询。 +func TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota(t *testing.T) { + called := &fakeZeroQuotaCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: called, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + err := s.checkUserPlatformQuotaEligibility(context.Background(), 99, "openai") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("non-subscription mode quota check should block, got: %v", err) + } + if !called.called { + t.Error("GetUserPlatformQuotaCache should be consulted in non-subscription mode") + } +} + +// ── B-3: monthlyQuotaWindowExpired 30 天边界表驱动测试 ──────────────────────── +// 覆盖 4 个必须场景: +// 1. 恰好 30 天 → expired +// 2. 30*24h - 1ns → not expired +// 3. 跨月末(2024-02-28 → 2024-03-29T00:00:01Z)→ expired +// 4. 跨年(2024-12-15 → 2025-01-14T00:00:01Z)→ expired +// +// repo 层 monthlyMaybeReset 不可导出,通过 service 层 monthlyQuotaWindowExpired 间接覆盖。 +func TestMonthlyQuotaWindowExpired_BoundaryTable(t *testing.T) { + const thirtyDays = 30 * 24 * time.Hour + + cases := []struct { + name string + start time.Time + now time.Time + expired bool + }{ + { + name: "exactly 30 days → expired", + start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays), + expired: true, + }, + { + name: "30d minus 1ns → not expired", + start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays - 1), + expired: false, + }, + { + name: "cross month-end (Feb→Mar, 29d+1s) → expired", + start: time.Date(2024, 2, 28, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 3, 29, 0, 0, 1, 0, time.UTC), + expired: true, + }, + { + name: "cross year boundary (Dec→Jan, 30d+1s) → expired", + start: time.Date(2024, 12, 15, 0, 0, 0, 0, time.UTC), + now: time.Date(2025, 1, 14, 0, 0, 1, 0, time.UTC), + expired: true, + }, + } + + for _, tc := range cases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + got := monthlyQuotaWindowExpired(&tc.start, tc.now) + if got != tc.expired { + t.Errorf("monthlyQuotaWindowExpired(start=%v, now=%v) = %v, want %v", + tc.start, tc.now, got, tc.expired) + } + }) + } +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 47975c8c..940a827d 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/config" ) @@ -20,6 +21,32 @@ type APIKeyRateLimitCacheData struct { Window7d int64 `json:"window_7d"` } +// UserPlatformQuotaCacheEntry Redis hash 反序列化结果。 +// +// SchemaVersion 用于向后兼容: +// - 0(旧 entry,无 SchemaVersion 字段)→ 视为 cache MISS,强制 refresh +// - 1(当前版本)→ 包含 limits 和 window_start,可免 DB 查询 +// +// limit 字段为 nil 表示"无限额"(DB 中对应列为 NULL)。 +const UserPlatformQuotaCacheSchemaV1 = int64(1) + +type UserPlatformQuotaCacheEntry struct { + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + Version int64 + SchemaVersion int64 + + // 以下字段仅在 SchemaVersion >= 1 时有效 + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations @@ -39,6 +66,13 @@ type BillingCache interface { SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error + + // user × platform quota 缓存 + GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) + SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error + DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error + // IncrUserPlatformQuotaUsageCache 在缓存命中时累加用量;缓存未命中(key 不存在)静默返回 nil。 + IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) @@ -482,6 +516,7 @@ func (s *BillingService) computeTokenBreakdown( inputPrice := pricing.InputPricePerToken outputPrice := pricing.OutputPricePerToken cacheReadPrice := pricing.CacheReadPricePerToken + cacheCreationMultiplier := 1.0 tierMultiplier := 1.0 if usePriorityServiceTierPricing(serviceTier, pricing) { @@ -501,6 +536,13 @@ func (s *BillingService) computeTokenBreakdown( if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) { inputPrice *= pricing.LongContextInputMultiplier outputPrice *= pricing.LongContextOutputMultiplier + // 缓存读取本质上是输入侧的复用,应与 input 一同应用长上下文倍率; + // 否则 cache hit 越多,少计的费用越多(见 #2293)。 + cacheReadPrice *= pricing.LongContextInputMultiplier + // 缓存创建(cache_write)也是输入侧操作,三档价格(标准 / 5m / 1h) + // 都通过 computeCacheCreationCost 直接读取 pricing.*,不会经过这里 + // 的倍率修改,因此显式向下传一个倍率,避免长上下文场景下被漏乘。 + cacheCreationMultiplier = pricing.LongContextInputMultiplier } bd := &CostBreakdown{} @@ -523,7 +565,7 @@ func (s *BillingService) computeTokenBreakdown( } // 缓存创建费用 - bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens) + bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens, cacheCreationMultiplier) bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice @@ -543,16 +585,17 @@ func (s *BillingService) computeTokenBreakdown( } // computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。 -func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 { +// multiplier 用于长上下文等场景下的整体价格缩放(普通调用传 1.0 即可)。 +func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens, multiplier float64) float64 { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 - return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice * multiplier } - return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice*multiplier + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice*multiplier } - return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken + return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken * multiplier } // calculatePerRequestCost 按次/图片计费 diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index df3e3a0a..0ab1f50d 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -197,6 +197,138 @@ func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *t require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) } +// 回归测试 #2293:长上下文计费触发时,cache_read_tokens 也应应用 LongContextInputMultiplier。 +// 修复前:CacheReadCost = tokens * 0.25e-6 (漏乘倍率,少计费用)。 +// 修复后:CacheReadCost = tokens * 0.25e-6 * LongContextInputMultiplier(=2.0)。 +func TestCalculateCost_OpenAIGPT54LongContextAppliesMultiplierToCacheRead(t *testing.T) { + svc := newTestBillingService() + + // InputTokens + CacheReadTokens = 1000 + 300000 = 301000 > 272000 阈值 + tokens := UsageTokens{ + InputTokens: 1000, + CacheReadTokens: 300000, + OutputTokens: 1000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0 + expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5 + expectedCacheRead := float64(tokens.CacheReadTokens) * 0.25e-6 * 2.0 + + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10, + "cache_read_cost should be scaled by LongContextInputMultiplier when long-context pricing applies (issue #2293)") + + expectedTotal := expectedInput + expectedOutput + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) + require.InDelta(t, expectedTotal, cost.ActualCost, 1e-10) +} + +// 阴性测试:未触发长上下文时,cache_read_price 不应被错误地乘以倍率。 +func TestCalculateCost_OpenAIGPT54NoLongContextKeepsCacheReadAtBasePrice(t *testing.T) { + svc := newTestBillingService() + + // InputTokens + CacheReadTokens = 1000 + 100000 = 101000 < 272000 阈值,不触发长上下文 + tokens := UsageTokens{ + InputTokens: 1000, + CacheReadTokens: 100000, + OutputTokens: 1000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedCacheRead := float64(tokens.CacheReadTokens) * 0.25e-6 + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10, + "cache_read_cost should remain at base price when below long-context threshold") +} + +// 回归测试 #2816 follow-up:长上下文计费触发时,cache_creation_tokens 也应应用 +// LongContextInputMultiplier。computeCacheCreationCost 直接读取 pricing.* 价格, +// 不经过 computeTokenBreakdown 内的 inputPrice / cacheReadPrice 倍率修改,因此 +// 修复前 cache_creation 部分会按基础价计算,少计费用约 50%(默认倍率 2.0)。 +func TestCalculateCost_OpenAIGPT54LongContextAppliesMultiplierToCacheCreation(t *testing.T) { + svc := newTestBillingService() + + // InputTokens + CacheReadTokens = 1000 + 300000 = 301000 > 272000 阈值 + tokens := UsageTokens{ + InputTokens: 1000, + CacheReadTokens: 300000, + CacheCreationTokens: 10000, + OutputTokens: 1000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + // gpt-5.4 fallback: CacheCreationPricePerToken = 2.5e-6, LongContextInputMultiplier = 2.0 + expectedCacheCreation := float64(tokens.CacheCreationTokens) * 2.5e-6 * 2.0 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10, + "cache_creation_cost should be scaled by LongContextInputMultiplier when long-context pricing applies") +} + +// 阴性测试:未触发长上下文时,cache_creation_price 不应被错误地乘以倍率。 +func TestCalculateCost_OpenAIGPT54NoLongContextKeepsCacheCreationAtBasePrice(t *testing.T) { + svc := newTestBillingService() + + // InputTokens + CacheReadTokens = 1000 + 100000 = 101000 < 272000 阈值,不触发长上下文 + tokens := UsageTokens{ + InputTokens: 1000, + CacheReadTokens: 100000, + CacheCreationTokens: 10000, + OutputTokens: 1000, + } + + cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := float64(tokens.CacheCreationTokens) * 2.5e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10, + "cache_creation_cost should remain at base price when below long-context threshold") +} + +// 覆盖 5m / 1h ephemeral 分类计费路径:长上下文触发时两档价格都应被倍率缩放。 +// 使用手工构造的 pricing(参考 TestCalculateCost_SupportsCacheBreakdown 的写法) +// 以便同时控制 SupportsCacheBreakdown + 长上下文阈值。 +func TestCalculateCost_LongContextAppliesMultiplierToCacheCreation5mAnd1h(t *testing.T) { + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + CacheReadPricePerToken: 0.3e-6, + SupportsCacheBreakdown: true, + CacheCreation5mPrice: 4e-6, + CacheCreation1hPrice: 5e-6, + LongContextInputThreshold: 272000, + LongContextInputMultiplier: 2.0, + LongContextOutputMultiplier: 1.5, + }, + }, + } + + // InputTokens + CacheReadTokens = 1000 + 300000 = 301000 > 272000 阈值 + tokens := UsageTokens{ + InputTokens: 1000, + CacheReadTokens: 300000, + CacheCreation5mTokens: 8000, + CacheCreation1hTokens: 4000, + OutputTokens: 1000, + } + + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6 * 2.0 + expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6 * 2.0 + require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10, + "both 5m and 1h cache_creation prices should be scaled by LongContextInputMultiplier") +} + func TestGetFallbackPricing_FamilyMatching(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go index b5a889e1..ee1fca41 100644 --- a/backend/internal/service/content_moderation.go +++ b/backend/internal/service/content_moderation.go @@ -177,6 +177,7 @@ type ContentModerationConfigView struct { AllGroups bool `json:"all_groups"` GroupIDs []int64 `json:"group_ids"` RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` WorkerCount int `json:"worker_count"` QueueSize int `json:"queue_size"` BlockStatus int `json:"block_status"` @@ -210,6 +211,20 @@ type ContentModerationAPIKeyStatus struct { Configured bool `json:"configured"` } +type ContentModerationAPIKeyLoad struct { + Index int `json:"index"` + KeyHash string `json:"key_hash"` + Masked string `json:"masked"` + Status string `json:"status"` + Active int64 `json:"active"` + Total int64 `json:"total"` + Success int64 `json:"success"` + Errors int64 `json:"errors"` + AvgLatencyMS int64 `json:"avg_latency_ms"` + LastLatencyMS int `json:"last_latency_ms"` + LastHTTPStatus int `json:"last_http_status"` +} + type TestContentModerationAPIKeysInput struct { APIKeys []string `json:"api_keys"` BaseURL string `json:"base_url"` @@ -249,6 +264,7 @@ type UpdateContentModerationConfigInput struct { AllGroups *bool `json:"all_groups"` GroupIDs *[]int64 `json:"group_ids"` RecordNonHits *bool `json:"record_non_hits"` + Thresholds *map[string]float64 `json:"thresholds"` WorkerCount *int `json:"worker_count"` QueueSize *int `json:"queue_size"` BlockStatus *int `json:"block_status"` @@ -397,25 +413,35 @@ type ContentModerationCleanupResult struct { } type ContentModerationRuntimeStatus struct { - Enabled bool `json:"enabled"` - RiskControlEnabled bool `json:"risk_control_enabled"` - Mode string `json:"mode"` - WorkerCount int `json:"worker_count"` - MaxWorkers int `json:"max_workers"` - ActiveWorkers int `json:"active_workers"` - IdleWorkers int `json:"idle_workers"` - QueueSize int `json:"queue_size"` - QueueLength int `json:"queue_length"` - QueueUsagePercent float64 `json:"queue_usage_percent"` - Enqueued int64 `json:"enqueued"` - Dropped int64 `json:"dropped"` - Processed int64 `json:"processed"` - Errors int64 `json:"errors"` - APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` - FlaggedHashCount int64 `json:"flagged_hash_count"` - LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"` - LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"` - LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"` + Enabled bool `json:"enabled"` + RiskControlEnabled bool `json:"risk_control_enabled"` + Mode string `json:"mode"` + WorkerCount int `json:"worker_count"` + MaxWorkers int `json:"max_workers"` + ActiveWorkers int `json:"active_workers"` + IdleWorkers int `json:"idle_workers"` + QueueSize int `json:"queue_size"` + QueueLength int `json:"queue_length"` + QueueUsagePercent float64 `json:"queue_usage_percent"` + Enqueued int64 `json:"enqueued"` + Dropped int64 `json:"dropped"` + Processed int64 `json:"processed"` + Errors int64 `json:"errors"` + PreBlockActive int `json:"pre_block_active"` + PreBlockChecked int64 `json:"pre_block_checked"` + PreBlockAllowed int64 `json:"pre_block_allowed"` + PreBlockBlocked int64 `json:"pre_block_blocked"` + PreBlockErrors int64 `json:"pre_block_errors"` + PreBlockAvgLatencyMS int64 `json:"pre_block_avg_latency_ms"` + PreBlockAPIKeyActive int64 `json:"pre_block_api_key_active"` + PreBlockAPIKeyAvailableCount int64 `json:"pre_block_api_key_available_count"` + PreBlockAPIKeyTotalCalls int64 `json:"pre_block_api_key_total_calls"` + PreBlockAPIKeyLoads []ContentModerationAPIKeyLoad `json:"pre_block_api_key_loads"` + APIKeyStatuses []ContentModerationAPIKeyStatus `json:"api_key_statuses"` + FlaggedHashCount int64 `json:"flagged_hash_count"` + LastCleanupAt *time.Time `json:"last_cleanup_at,omitempty"` + LastCleanupDeletedHit int64 `json:"last_cleanup_deleted_hit"` + LastCleanupDeletedNonHit int64 `json:"last_cleanup_deleted_non_hit"` } type ContentModerationUnbanUserResult struct { @@ -464,6 +490,12 @@ type ContentModerationService struct { asyncDropped atomic.Int64 asyncProcessed atomic.Int64 asyncErrors atomic.Int64 + preBlockActive atomic.Int64 + preBlockChecked atomic.Int64 + preBlockAllowed atomic.Int64 + preBlockBlocked atomic.Int64 + preBlockErrors atomic.Int64 + preBlockLatencyTotalMS atomic.Int64 lastCleanupUnix atomic.Int64 lastCleanupDeletedHit atomic.Int64 lastCleanupDeletedNonHit atomic.Int64 @@ -472,10 +504,14 @@ type ContentModerationService struct { } type contentModerationTask struct { - input ContentModerationCheckInput - content ContentModerationInput - inputHash string - enqueuedAt time.Time + input ContentModerationCheckInput + content ContentModerationInput + inputHash string + log *ContentModerationLog + config *ContentModerationConfig + recordHash bool + applySideEffects bool + enqueuedAt time.Time } type contentModerationKeyHealth struct { @@ -489,6 +525,11 @@ type contentModerationKeyHealth struct { LastLatencyMS int LastHTTPStatus int LastTested bool + SyncActive int64 + SyncTotal int64 + SyncSuccess int64 + SyncErrors int64 + SyncLatencyMS int64 } func NewContentModerationService( @@ -607,6 +648,9 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat if input.RecordNonHits != nil { cfg.RecordNonHits = *input.RecordNonHits } + if input.Thresholds != nil { + cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), *input.Thresholds) + } if input.ClearAPIKey { cfg.APIKey = "" cfg.APIKeys = []string{} @@ -822,9 +866,11 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "protocol", input.Protocol, "text_runes", len([]rune(content.Text)), "image_count", len(content.Images)) + hashText := content.Hash() if cfg.Mode == ContentModerationModePreBlock { if cfg.KeywordBlockingMode != ContentModerationKeywordModeAPIOnly && len(cfg.BlockedKeywords) > 0 { if keyword, hit := matchBlockedKeyword(content.Text, cfg.BlockedKeywords); hit { + s.recordPreBlockSyncMetric(0, ContentModerationActionKeywordBlock) slog.Info("content_moderation.keyword_block", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -835,8 +881,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "keyword", keyword) scores := map[string]float64{contentModerationKeywordCategory: 1.0} log := s.buildLog(input, cfg, ContentModerationActionKeywordBlock, true, contentModerationKeywordCategory, 1.0, scores, content.ExcerptText(), nil, nil, "") - s.applyFlaggedSideEffects(ctx, cfg, log) - _ = s.repo.CreateLog(ctx, log) + s.enqueueRecord(input, cfg, log, hashText, false, true) return &ContentModerationDecision{ Allowed: false, Blocked: true, @@ -851,6 +896,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer } } if cfg.KeywordBlockingMode == ContentModerationKeywordModeKeywordOnly { + s.recordPreBlockSyncMetric(0, ContentModerationActionAllow) slog.Info("content_moderation.skip_api_keyword_only", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -860,13 +906,15 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer return allow, nil } } - hashText := content.Hash() if cfg.PreHashCheckEnabled && s.hashCache != nil { matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText) if err != nil { slog.Warn("content_moderation.hash_check_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) } if matched { + if cfg.Mode == ContentModerationModePreBlock { + s.recordPreBlockSyncMetric(0, ContentModerationActionHashBlock) + } slog.Info("content_moderation.hash_block", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -878,6 +926,9 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer if message != "" { message = fmt.Sprintf("%s(hash: %s)", message, hashText) } + scores := map[string]float64{"hash": 1.0} + log := s.buildLog(input, cfg, ContentModerationActionHashBlock, true, "hash", 1.0, scores, content.ExcerptText(), nil, nil, "") + s.enqueueRecord(input, cfg, log, hashText, false, false) return &ContentModerationDecision{ Allowed: false, Blocked: true, @@ -890,6 +941,9 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer } } if !cfg.shouldSample(hashText) { + if cfg.Mode == ContentModerationModePreBlock { + s.recordPreBlockSyncMetric(0, ContentModerationActionAllow) + } slog.Info("content_moderation.skip_sample_rate", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -900,6 +954,9 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer return allow, nil } if len(cfg.apiKeys()) == 0 { + if cfg.Mode == ContentModerationModePreBlock { + s.recordPreBlockSyncMetric(0, ContentModerationActionError) + } slog.Warn("content_moderation.skip_no_audit_api_keys", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -925,10 +982,18 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer func (s *ContentModerationService) checkSync(ctx context.Context, input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string, queueDelay *int, allowBlock bool) *ContentModerationDecision { allow := &ContentModerationDecision{Allowed: true, Action: ContentModerationActionAllow} + trackPreBlock := queueDelay == nil && allowBlock && cfg != nil && cfg.Mode == ContentModerationModePreBlock + if trackPreBlock { + s.preBlockActive.Add(1) + defer s.preBlockActive.Add(-1) + } start := time.Now() - result, err := s.callModeration(ctx, cfg, content.ModerationInput()) + result, err := s.callModeration(ctx, cfg, content.ModerationInput(), trackPreBlock) latency := int(time.Since(start).Milliseconds()) if err != nil { + if trackPreBlock { + s.recordPreBlockSyncMetric(latency, ContentModerationActionError) + } slog.Warn("content_moderation.audit_api_failed", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -957,6 +1022,9 @@ func (s *ContentModerationService) checkSync(ctx context.Context, input ContentM action = ContentModerationActionBlock blocked = true } + if trackPreBlock { + s.recordPreBlockSyncMetric(latency, action) + } slog.Info("content_moderation.audit_result", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -975,13 +1043,11 @@ func (s *ContentModerationService) checkSync(ctx context.Context, input ContentM "queue_delay_ms", queueDelay) if flagged || cfg.RecordNonHits { log := s.buildLog(input, cfg, action, flagged, highestCategory, highestScore, result.CategoryScores, content.ExcerptText(), &latency, queueDelay, "") - if flagged && s.hashCache != nil { - if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil { - slog.Warn("content_moderation.record_hash_failed", "user_id", input.UserID, "endpoint", input.Endpoint, "error", err) - } + if queueDelay == nil && cfg.Mode == ContentModerationModePreBlock { + s.enqueueRecord(input, cfg, log, hashText, flagged, flagged) + } else { + s.persistContentModerationLog(ctx, cfg, log, hashText, flagged, flagged) } - s.applyFlaggedSideEffects(ctx, cfg, log) - _ = s.repo.CreateLog(ctx, log) } if blocked { return &ContentModerationDecision{ @@ -1007,6 +1073,25 @@ func (s *ContentModerationService) checkSync(ctx context.Context, input ContentM } } +func (s *ContentModerationService) recordPreBlockSyncMetric(latencyMS int, action string) { + if s == nil { + return + } + s.preBlockChecked.Add(1) + if latencyMS < 0 { + latencyMS = 0 + } + s.preBlockLatencyTotalMS.Add(int64(latencyMS)) + switch action { + case ContentModerationActionBlock, ContentModerationActionHashBlock, ContentModerationActionKeywordBlock: + s.preBlockBlocked.Add(1) + case ContentModerationActionError: + s.preBlockErrors.Add(1) + default: + s.preBlockAllowed.Add(1) + } +} + func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInput, cfg *ContentModerationConfig, content ContentModerationInput, hashText string) { if s == nil || s.asyncQueue == nil { return @@ -1035,11 +1120,49 @@ func (s *ContentModerationService) enqueueAsync(input ContentModerationCheckInpu } } +func (s *ContentModerationService) enqueueRecord(input ContentModerationCheckInput, cfg *ContentModerationConfig, log *ContentModerationLog, inputHash string, recordHash bool, applySideEffects bool) { + if s == nil || s.asyncQueue == nil || log == nil { + return + } + queueSize := defaultContentModerationQueueSize + if cfg != nil && cfg.QueueSize > 0 { + queueSize = cfg.QueueSize + } + if len(s.asyncQueue) >= queueSize { + slog.Warn("content_moderation.record_queue_full", + "user_id", input.UserID, + "endpoint", input.Endpoint, + "action", log.Action, + "queue_size", queueSize) + s.asyncDropped.Add(1) + return + } + task := contentModerationTask{ + input: input, + inputHash: inputHash, + log: log, + config: cloneContentModerationConfig(cfg), + recordHash: recordHash, + applySideEffects: applySideEffects, + enqueuedAt: time.Now(), + } + select { + case s.asyncQueue <- task: + s.asyncEnqueued.Add(1) + default: + slog.Warn("content_moderation.record_queue_full", + "user_id", input.UserID, + "endpoint", input.Endpoint, + "action", log.Action) + s.asyncDropped.Add(1) + } +} + func (s *ContentModerationService) worker(id int) { for { ctx, cancel := context.WithTimeout(context.Background(), maxContentModerationTimeoutMS*time.Millisecond+10*time.Second) cfg, err := s.loadConfig(ctx) - if err != nil || !cfg.Enabled || cfg.Mode == ContentModerationModeOff || len(cfg.apiKeys()) == 0 || id >= cfg.WorkerCount { + if err != nil || id >= cfg.WorkerCount { cancel() time.Sleep(time.Second) continue @@ -1056,6 +1179,22 @@ func (s *ContentModerationService) worker(id int) { slog.Error("content_moderation.worker_panic", "worker_id", id, "recover", r) } }() + if task.log != nil { + s.asyncActive.Add(1) + defer s.asyncActive.Add(-1) + queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) + task.log.QueueDelayMS = &queueDelay + taskCfg := task.config + if taskCfg == nil { + taskCfg = cfg + } + s.persistContentModerationLog(ctx, taskCfg, task.log, task.inputHash, task.recordHash, task.applySideEffects) + s.asyncProcessed.Add(1) + return + } + if !cfg.Enabled || cfg.Mode == ContentModerationModeOff || len(cfg.apiKeys()) == 0 { + return + } if !cfg.includesGroup(task.input.GroupID) { return } @@ -1181,6 +1320,15 @@ func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModer if active > cfg.WorkerCount { active = cfg.WorkerCount } + preBlockActive := int(s.preBlockActive.Load()) + if preBlockActive < 0 { + preBlockActive = 0 + } + preBlockChecked := s.preBlockChecked.Load() + preBlockAvgLatency := int64(0) + if preBlockChecked > 0 { + preBlockAvgLatency = s.preBlockLatencyTotalMS.Load() / preBlockChecked + } queueLength := 0 if s.asyncQueue != nil { queueLength = len(s.asyncQueue) @@ -1203,25 +1351,35 @@ func (s *ContentModerationService) GetStatus(ctx context.Context) (*ContentModer lastCleanupAt = &t } return &ContentModerationRuntimeStatus{ - Enabled: cfg.Enabled, - RiskControlEnabled: riskEnabled, - Mode: cfg.Mode, - WorkerCount: cfg.WorkerCount, - MaxWorkers: maxContentModerationWorkerCount, - ActiveWorkers: active, - IdleWorkers: cfg.WorkerCount - active, - QueueSize: cfg.QueueSize, - QueueLength: queueLength, - QueueUsagePercent: queueUsage, - Enqueued: s.asyncEnqueued.Load(), - Dropped: s.asyncDropped.Load(), - Processed: s.asyncProcessed.Load(), - Errors: s.asyncErrors.Load(), - APIKeyStatuses: s.apiKeyStatuses(cfg.apiKeys()), - FlaggedHashCount: flaggedHashCount, - LastCleanupAt: lastCleanupAt, - LastCleanupDeletedHit: s.lastCleanupDeletedHit.Load(), - LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(), + Enabled: cfg.Enabled, + RiskControlEnabled: riskEnabled, + Mode: cfg.Mode, + WorkerCount: cfg.WorkerCount, + MaxWorkers: maxContentModerationWorkerCount, + ActiveWorkers: active, + IdleWorkers: cfg.WorkerCount - active, + QueueSize: cfg.QueueSize, + QueueLength: queueLength, + QueueUsagePercent: queueUsage, + Enqueued: s.asyncEnqueued.Load(), + Dropped: s.asyncDropped.Load(), + Processed: s.asyncProcessed.Load(), + Errors: s.asyncErrors.Load(), + PreBlockActive: preBlockActive, + PreBlockChecked: preBlockChecked, + PreBlockAllowed: s.preBlockAllowed.Load(), + PreBlockBlocked: s.preBlockBlocked.Load(), + PreBlockErrors: s.preBlockErrors.Load(), + PreBlockAvgLatencyMS: preBlockAvgLatency, + PreBlockAPIKeyActive: s.preBlockAPIKeyActive(cfg.apiKeys()), + PreBlockAPIKeyAvailableCount: s.preBlockAPIKeyAvailableCount(cfg.apiKeys()), + PreBlockAPIKeyTotalCalls: s.preBlockAPIKeyTotalCalls(cfg.apiKeys()), + PreBlockAPIKeyLoads: s.preBlockAPIKeyLoads(cfg.apiKeys()), + APIKeyStatuses: s.apiKeyStatuses(cfg.apiKeys()), + FlaggedHashCount: flaggedHashCount, + LastCleanupAt: lastCleanupAt, + LastCleanupDeletedHit: s.lastCleanupDeletedHit.Load(), + LastCleanupDeletedNonHit: s.lastCleanupDeletedNonHit.Load(), }, nil } @@ -1320,7 +1478,7 @@ func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *Cont return nil } -func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any) (*moderationAPIResult, error) { +func (s *ContentModerationService) callModeration(ctx context.Context, cfg *ContentModerationConfig, input any, trackKeyLoad ...bool) (*moderationAPIResult, error) { attempts := cfg.RetryCount + 1 if attempts <= 0 { attempts = 1 @@ -1328,6 +1486,7 @@ func (s *ContentModerationService) callModeration(ctx context.Context, cfg *Cont if attempts > maxContentModerationRetryCount+1 { attempts = maxContentModerationRetryCount + 1 } + trackLoad := len(trackKeyLoad) > 0 && trackKeyLoad[0] var lastErr error for attempt := 0; attempt < attempts; attempt++ { key, ok := s.nextUsableAPIKey(cfg) @@ -1335,14 +1494,23 @@ func (s *ContentModerationService) callModeration(ctx context.Context, cfg *Cont lastErr = errors.New("no moderation api key available") break } + if trackLoad { + s.beginModerationAPIKeyCall(key) + } start := time.Now() httpStatus := 0 result, err := s.callModerationOnceWithInput(ctx, cfg, key, input, &httpStatus) latency := int(time.Since(start).Milliseconds()) if err == nil { + if trackLoad { + s.finishModerationAPIKeyCall(key, latency, true) + } s.markAPIKeySuccess(key, latency, httpStatus) return result, nil } + if trackLoad { + s.finishModerationAPIKeyCall(key, latency, false) + } s.markAPIKeyError(key, err.Error(), latency, httpStatus) lastErr = err if httpStatus == http.StatusBadRequest { @@ -1447,10 +1615,32 @@ func (s *ContentModerationService) buildLog(input ContentModerationCheckInput, c } } -func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) { - if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 { +func (s *ContentModerationService) persistContentModerationLog(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog, hashText string, recordHash bool, applySideEffects bool) { + if s == nil || log == nil { return } + if recordHash && s.hashCache != nil { + if err := s.hashCache.RecordFlaggedInputHash(ctx, hashText); err != nil { + slog.Warn("content_moderation.record_hash_failed", "user_id", contentModerationEmailUserID(log), "endpoint", log.Endpoint, "error", err) + } + } + autoBanJustApplied := false + if applySideEffects { + autoBanJustApplied = s.applyFlaggedAccountSideEffects(ctx, cfg, log) + s.sendFlaggedNotificationSideEffects(ctx, cfg, log, autoBanJustApplied) + } + if s.repo != nil { + if err := s.repo.CreateLog(ctx, log); err != nil { + slog.Warn("content_moderation.create_log_failed", "user_id", contentModerationEmailUserID(log), "endpoint", log.Endpoint, "action", log.Action, "error", err) + return + } + } +} + +func (s *ContentModerationService) applyFlaggedAccountSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog) bool { + if s == nil || cfg == nil || log == nil || !log.Flagged || log.UserID == nil || *log.UserID <= 0 { + return false + } count := 1 if s.repo != nil && cfg.ViolationWindowHours > 0 { since := time.Now().Add(-time.Duration(cfg.ViolationWindowHours) * time.Hour) @@ -1464,13 +1654,13 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, user, err := s.userRepo.GetByID(ctx, *log.UserID) if err != nil { slog.Warn("content_moderation.ban_get_user_failed", "user_id", *log.UserID, "error", err) - return + return false } if user.Status != StatusDisabled { user.Status = StatusDisabled if err := s.userRepo.Update(ctx, user); err != nil { slog.Warn("content_moderation.ban_update_user_failed", "user_id", *log.UserID, "error", err) - return + return false } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, *log.UserID) @@ -1479,7 +1669,13 @@ func (s *ContentModerationService) applyFlaggedSideEffects(ctx context.Context, } log.AutoBanned = true } + return autoBanJustApplied +} +func (s *ContentModerationService) sendFlaggedNotificationSideEffects(ctx context.Context, cfg *ContentModerationConfig, log *ContentModerationLog, autoBanJustApplied bool) { + if s == nil || cfg == nil || log == nil || !log.Flagged { + return + } if s.emailService == nil || strings.TrimSpace(log.UserEmail) == "" { return } @@ -1637,6 +1833,22 @@ func defaultContentModerationConfig() *ContentModerationConfig { } } +func cloneContentModerationConfig(cfg *ContentModerationConfig) *ContentModerationConfig { + if cfg == nil { + return nil + } + clone := *cfg + clone.APIKeys = append([]string(nil), cfg.APIKeys...) + clone.GroupIDs = append([]int64(nil), cfg.GroupIDs...) + clone.BlockedKeywords = append([]string(nil), cfg.BlockedKeywords...) + clone.Thresholds = cloneFloatMap(cfg.Thresholds) + clone.ModelFilter = ContentModerationModelFilter{ + Type: cfg.ModelFilter.Type, + Models: append([]string(nil), cfg.ModelFilter.Models...), + } + return &clone +} + func (cfg *ContentModerationConfig) normalize() { if cfg.APIKey != "" { cfg.APIKeys = normalizeModerationAPIKeys(append(cfg.APIKeys, cfg.APIKey)) @@ -1802,6 +2014,40 @@ func (s *ContentModerationService) isAPIKeyFrozen(key string, now time.Time) boo return state != nil && state.FrozenUntil.After(now) } +func (s *ContentModerationService) beginModerationAPIKeyCall(key string) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + state.SyncActive++ +} + +func (s *ContentModerationService) finishModerationAPIKeyCall(key string, latencyMS int, success bool) { + hash := moderationAPIKeyHash(key) + if hash == "" || s == nil { + return + } + if latencyMS < 0 { + latencyMS = 0 + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.ensureAPIKeyHealthLocked(hash, maskSecretTail(key)) + if state.SyncActive > 0 { + state.SyncActive-- + } + state.SyncTotal++ + state.SyncLatencyMS += int64(latencyMS) + if success { + state.SyncSuccess++ + return + } + state.SyncErrors++ +} + func (s *ContentModerationService) markAPIKeySuccess(key string, latencyMS int, httpStatus int) { hash := moderationAPIKeyHash(key) if hash == "" || s == nil { @@ -1894,6 +2140,7 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con AllGroups: cfg.AllGroups, GroupIDs: append([]int64(nil), cfg.GroupIDs...), RecordNonHits: cfg.RecordNonHits, + Thresholds: cloneFloatMap(cfg.Thresholds), WorkerCount: cfg.WorkerCount, QueueSize: cfg.QueueSize, BlockStatus: cfg.BlockStatus, @@ -1920,6 +2167,71 @@ func (s *ContentModerationService) apiKeyStatuses(keys []string) []ContentModera return out } +func (s *ContentModerationService) preBlockAPIKeyLoads(keys []string) []ContentModerationAPIKeyLoad { + out := make([]ContentModerationAPIKeyLoad, 0, len(keys)) + for idx, key := range keys { + out = append(out, s.preBlockAPIKeyLoadForHash(idx, moderationAPIKeyHash(key), maskSecretTail(key))) + } + return out +} + +func (s *ContentModerationService) preBlockAPIKeyActive(keys []string) int64 { + var total int64 + for _, item := range s.preBlockAPIKeyLoads(keys) { + total += item.Active + } + return total +} + +func (s *ContentModerationService) preBlockAPIKeyAvailableCount(keys []string) int64 { + now := time.Now() + var count int64 + for _, key := range keys { + if !s.isAPIKeyFrozen(key, now) { + count++ + } + } + return count +} + +func (s *ContentModerationService) preBlockAPIKeyTotalCalls(keys []string) int64 { + var total int64 + for _, item := range s.preBlockAPIKeyLoads(keys) { + total += item.Total + } + return total +} + +func (s *ContentModerationService) preBlockAPIKeyLoadForHash(index int, hash string, masked string) ContentModerationAPIKeyLoad { + load := ContentModerationAPIKeyLoad{ + Index: index, + KeyHash: hash, + Masked: masked, + Status: "unknown", + } + status := s.apiKeyStatusForHash(index, hash, masked, true) + load.Status = status.Status + load.LastLatencyMS = status.LastLatencyMS + load.LastHTTPStatus = status.LastHTTPStatus + if hash == "" || s == nil { + return load + } + s.keyHealthMu.Lock() + defer s.keyHealthMu.Unlock() + state := s.keyHealth[hash] + if state == nil { + return load + } + load.Active = state.SyncActive + load.Total = state.SyncTotal + load.Success = state.SyncSuccess + load.Errors = state.SyncErrors + if state.SyncTotal > 0 { + load.AvgLatencyMS = state.SyncLatencyMS / state.SyncTotal + } + return load +} + func (s *ContentModerationService) apiKeyStatusForHash(index int, hash string, masked string, configured bool) ContentModerationAPIKeyStatus { status := ContentModerationAPIKeyStatus{ Index: index, diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go index 60a99318..1fb72f36 100644 --- a/backend/internal/service/content_moderation_test.go +++ b/backend/internal/service/content_moderation_test.go @@ -3,9 +3,11 @@ package service import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -73,10 +75,13 @@ func (r *contentModerationTestSettingRepo) Delete(ctx context.Context, key strin } type contentModerationTestRepo struct { + mu sync.Mutex logs []ContentModerationLog } func (r *contentModerationTestRepo) CreateLog(ctx context.Context, log *ContentModerationLog) error { + r.mu.Lock() + defer r.mu.Unlock() if log != nil { r.logs = append(r.logs, *log) } @@ -88,14 +93,55 @@ func (r *contentModerationTestRepo) ListLogs(ctx context.Context, filter Content } func (r *contentModerationTestRepo) CountFlaggedByUserSince(ctx context.Context, userID int64, since time.Time) (int, error) { - return 0, nil + r.mu.Lock() + defer r.mu.Unlock() + count := 0 + for _, log := range r.logs { + if log.UserID == nil || *log.UserID != userID || !log.Flagged || log.Action == ContentModerationActionHashBlock { + continue + } + if log.CreatedAt.IsZero() || log.CreatedAt.Before(since) { + continue + } + count++ + } + return count, nil } func (r *contentModerationTestRepo) CleanupExpiredLogs(ctx context.Context, hitBefore time.Time, nonHitBefore time.Time) (*ContentModerationCleanupResult, error) { return &ContentModerationCleanupResult{}, nil } +func (r *contentModerationTestRepo) snapshotLogs() []ContentModerationLog { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]ContentModerationLog, len(r.logs)) + copy(out, r.logs) + return out +} + +func requireContentModerationLogCount(t *testing.T, repo *contentModerationTestRepo, want int) []ContentModerationLog { + t.Helper() + var logs []ContentModerationLog + require.Eventually(t, func() bool { + logs = repo.snapshotLogs() + return len(logs) == want + }, time.Second, 10*time.Millisecond) + return logs +} + +func requireRecordedHashCount(t *testing.T, cache *contentModerationTestHashCache, want int) []string { + t.Helper() + var hashes []string + require.Eventually(t, func() bool { + hashes = cache.snapshotRecorded() + return len(hashes) == want + }, time.Second, 10*time.Millisecond) + return hashes +} + type contentModerationTestHashCache struct { + mu sync.Mutex hashes map[string]struct{} recorded []string checked []string @@ -246,6 +292,8 @@ func (i *contentModerationTestAuthCacheInvalidator) InvalidateAuthCacheByGroupID } func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Context, inputHash string) error { + c.mu.Lock() + defer c.mu.Unlock() if c.hashes == nil { c.hashes = map[string]struct{}{} } @@ -255,6 +303,8 @@ func (c *contentModerationTestHashCache) RecordFlaggedInputHash(ctx context.Cont } func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.mu.Lock() + defer c.mu.Unlock() c.checked = append(c.checked, inputHash) if c.hasResultUsed { return c.hasResult, nil @@ -264,6 +314,8 @@ func (c *contentModerationTestHashCache) HasFlaggedInputHash(ctx context.Context } func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Context, inputHash string) (bool, error) { + c.mu.Lock() + defer c.mu.Unlock() c.deleted = append(c.deleted, inputHash) if c.hashes == nil { return false, nil @@ -276,15 +328,50 @@ func (c *contentModerationTestHashCache) DeleteFlaggedInputHash(ctx context.Cont } func (c *contentModerationTestHashCache) ClearFlaggedInputHashes(ctx context.Context) (int64, error) { + c.mu.Lock() + defer c.mu.Unlock() deleted := int64(len(c.hashes)) c.hashes = map[string]struct{}{} return deleted, nil } func (c *contentModerationTestHashCache) CountFlaggedInputHashes(ctx context.Context) (int64, error) { + c.mu.Lock() + defer c.mu.Unlock() return int64(len(c.hashes)), nil } +func (c *contentModerationTestHashCache) snapshotRecorded() []string { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]string, len(c.recorded)) + copy(out, c.recorded) + return out +} + +func (c *contentModerationTestHashCache) snapshotChecked() []string { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]string, len(c.checked)) + copy(out, c.checked) + return out +} + +func (c *contentModerationTestHashCache) hasHash(inputHash string) bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.hashes[inputHash] + return ok +} + +func (c *contentModerationTestHashCache) snapshotDeleted() []string { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]string, len(c.deleted)) + copy(out, c.deleted) + return out +} + func TestBuildContentModerationLog_RedactsInputExcerpt(t *testing.T) { svc := &ContentModerationService{} cfg := defaultContentModerationConfig() @@ -381,10 +468,10 @@ func TestContentModerationCheck_PreBlockKeywordHitSkipsUpstreamCall(t *testing.T require.True(t, decision.Blocked) require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) require.False(t, upstreamCalled, "keyword block must short-circuit upstream moderation call") - require.Len(t, repo.logs, 1) - require.True(t, repo.logs[0].Flagged) - require.Equal(t, ContentModerationActionKeywordBlock, repo.logs[0].Action) - require.Equal(t, contentModerationKeywordCategory, repo.logs[0].HighestCategory) + logs := requireContentModerationLogCount(t, repo, 1) + require.True(t, logs[0].Flagged) + require.Equal(t, ContentModerationActionKeywordBlock, logs[0].Action) + require.Equal(t, contentModerationKeywordCategory, logs[0].HighestCategory) } func TestContentModerationCheck_KeywordsIgnoredInObserveMode(t *testing.T) { @@ -474,7 +561,7 @@ func TestContentModerationCheck_KeywordOnlyStrategySkipsAPIOnMiss(t *testing.T) require.NoError(t, err) require.True(t, decision.Allowed, "keyword-only must allow misses without calling the API") require.False(t, upstreamCalled, "keyword-only must not call the upstream moderation API") - require.Len(t, repo.logs, 0) + require.Len(t, repo.snapshotLogs(), 0) } func TestContentModerationCheck_APIOnlyStrategyIgnoresKeywordList(t *testing.T) { @@ -545,7 +632,7 @@ func TestContentModerationCheck_ModelFilterAllAuditsEveryModel(t *testing.T) { require.True(t, decision.Blocked) require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) } - require.Len(t, repo.logs, 2) + requireContentModerationLogCount(t, repo, 2) } func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *testing.T) { @@ -571,8 +658,8 @@ func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *test require.True(t, decision.Allowed) require.False(t, decision.Blocked) require.Equal(t, ContentModerationActionAllow, decision.Action) - require.Len(t, repo.logs, 1) - require.Equal(t, "gpt-5.5", repo.logs[0].Model) + logs := requireContentModerationLogCount(t, repo, 1) + require.Equal(t, "gpt-5.5", logs[0].Model) } func TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T) { @@ -598,8 +685,8 @@ func TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T require.True(t, decision.Allowed) require.False(t, decision.Blocked) require.Equal(t, ContentModerationActionAllow, decision.Action) - require.Len(t, repo.logs, 1) - require.Equal(t, "gpt-5.5", repo.logs[0].Model) + logs := requireContentModerationLogCount(t, repo, 1) + require.Equal(t, "gpt-5.5", logs[0].Model) } func TestContentModerationLoadConfig_LegacyConfigDefaultsModelFilterToAll(t *testing.T) { @@ -639,8 +726,8 @@ func TestContentModerationCheck_ModelFilterUsesRequestedModelNotBodyModel(t *tes require.NoError(t, err) require.True(t, decision.Blocked) require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) - require.Len(t, repo.logs, 1) - require.Equal(t, "gpt-5.5", repo.logs[0].Model) + logs := requireContentModerationLogCount(t, repo, 1) + require.Equal(t, "gpt-5.5", logs[0].Model) } func defaultContentModerationModelFilterTestConfig() *ContentModerationConfig { @@ -726,6 +813,37 @@ func TestContentModerationUpdateConfig_ReplacesAPIKeysWhenRequested(t *testing.T require.Equal(t, []string{"sk-new-only"}, saved.apiKeys()) } +func TestContentModerationUpdateConfig_SavesCustomThresholds(t *testing.T) { + cfg := defaultContentModerationConfig() + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + repo := &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyContentModerationConfig: string(rawCfg), + }} + svc := NewContentModerationService(repo, nil, nil, nil, nil, nil, nil) + thresholds := map[string]float64{ + "sexual": 0.72, + "harassment": 1.25, + "unknown": 0.01, + } + + view, err := svc.UpdateConfig(context.Background(), UpdateContentModerationConfigInput{ + Thresholds: &thresholds, + }) + + require.NoError(t, err) + require.Equal(t, 0.72, view.Thresholds["sexual"]) + require.Equal(t, 1.0, view.Thresholds["harassment"]) + require.NotContains(t, view.Thresholds, "unknown") + + var saved ContentModerationConfig + require.NoError(t, json.Unmarshal([]byte(repo.values[SettingKeyContentModerationConfig]), &saved)) + require.Equal(t, 0.72, saved.Thresholds["sexual"]) + require.Equal(t, 1.0, saved.Thresholds["harassment"]) + require.NotContains(t, saved.Thresholds, "unknown") +} + func TestExtractContentModerationInput_AnthropicImageSourceOnlyParticipatesInMemory(t *testing.T) { body := []byte(`{ "messages": [ @@ -908,11 +1026,11 @@ func TestContentModerationCheck_OpenAIResponsesRecordsNonHitForCodexPayload(t *t require.NoError(t, err) require.False(t, decision.Blocked) - require.Len(t, repo.logs, 1) - require.False(t, repo.logs[0].Flagged) - require.Equal(t, ContentModerationActionAllow, repo.logs[0].Action) - require.Equal(t, "/responses", repo.logs[0].Endpoint) - require.Equal(t, "last user prompt", repo.logs[0].InputExcerpt) + logs := requireContentModerationLogCount(t, repo, 1) + require.False(t, logs[0].Flagged) + require.Equal(t, ContentModerationActionAllow, logs[0].Action) + require.Equal(t, "/responses", logs[0].Endpoint) + require.Equal(t, "last user prompt", logs[0].InputExcerpt) require.Equal(t, "last user prompt", moderationRequest.Input) } @@ -976,14 +1094,164 @@ func TestContentModerationCheck_PreBlockBlocksCodexResponsesLatestUserInput(t *t require.Equal(t, ContentModerationActionBlock, decision.Action) require.Equal(t, http.StatusUnavailableForLegalReasons, decision.StatusCode) require.Equal(t, "内容审计测试阻断", decision.Message) - require.Len(t, repo.logs, 1) - require.True(t, repo.logs[0].Flagged) - require.Equal(t, ContentModerationActionBlock, repo.logs[0].Action) - require.Equal(t, ContentModerationModePreBlock, repo.logs[0].Mode) - require.Equal(t, "latest blocked prompt", repo.logs[0].InputExcerpt) + logs := requireContentModerationLogCount(t, repo, 1) + require.True(t, logs[0].Flagged) + require.Equal(t, ContentModerationActionBlock, logs[0].Action) + require.Equal(t, ContentModerationModePreBlock, logs[0].Mode) + require.Equal(t, "latest blocked prompt", logs[0].InputExcerpt) require.Equal(t, "latest blocked prompt", moderationRequest.Input) } +func TestContentModerationStatusTracksPreBlockSyncMetrics(t *testing.T) { + var requestCount int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + score := 0.01 + if requestCount == 1 { + score = 0.9 + } + time.Sleep(5 * time.Millisecond) + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": score}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + &contentModerationTestRepo{}, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + for _, prompt := range []string{"blocked prompt", "clean prompt"} { + _, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(fmt.Sprintf(`{"messages":[{"role":"user","content":%q}]}`, prompt)), + }) + require.NoError(t, err) + } + + status, err := svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), status.PreBlockChecked) + require.Equal(t, int64(1), status.PreBlockAllowed) + require.Equal(t, int64(1), status.PreBlockBlocked) + require.Equal(t, int64(0), status.PreBlockErrors) + require.Equal(t, 0, status.PreBlockActive) + require.GreaterOrEqual(t, status.PreBlockAvgLatencyMS, int64(1)) +} + +func TestContentModerationStatusTracksPreBlockAPIKeyLoad(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.01}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-one", "sk-two"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + &contentModerationTestRepo{}, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + for idx := 0; idx < 4; idx++ { + _, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(fmt.Sprintf(`{"messages":[{"role":"user","content":"prompt %d"}]}`, idx)), + }) + require.NoError(t, err) + } + + status, err := svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Len(t, status.PreBlockAPIKeyLoads, 2) + require.Equal(t, int64(4), status.PreBlockAPIKeyTotalCalls) + require.Equal(t, int64(2), status.PreBlockAPIKeyAvailableCount) + require.Equal(t, int64(0), status.PreBlockAPIKeyActive) + require.Equal(t, int64(0), status.PreBlockAPIKeyLoads[0].Active) + require.Equal(t, int64(2), status.PreBlockAPIKeyLoads[0].Total) + require.Equal(t, int64(2), status.PreBlockAPIKeyLoads[0].Success) + require.Equal(t, int64(0), status.PreBlockAPIKeyLoads[0].Errors) + require.Equal(t, int64(2), status.PreBlockAPIKeyLoads[1].Total) + require.Equal(t, int64(2), status.PreBlockAPIKeyLoads[1].Success) +} + +func TestContentModerationStatusTracksPreBlockLocalBlocks(t *testing.T) { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.KeywordBlockingMode = ContentModerationKeywordModeKeywordOnly + cfg.BlockedKeywords = []string{"blocked"} + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + &contentModerationTestRepo{}, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + for _, prompt := range []string{"blocked prompt", "clean prompt"} { + _, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(fmt.Sprintf(`{"messages":[{"role":"user","content":%q}]}`, prompt)), + }) + require.NoError(t, err) + } + + status, err := svc.GetStatus(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(2), status.PreBlockChecked) + require.Equal(t, int64(1), status.PreBlockAllowed) + require.Equal(t, int64(1), status.PreBlockBlocked) + require.Equal(t, int64(0), status.PreBlockErrors) +} + func TestBuildContentModerationTestAuditResult_UsesConfiguredThresholdsOnly(t *testing.T) { result := buildContentModerationTestAuditResult(&moderationAPIResult{ Flagged: true, @@ -1106,6 +1374,8 @@ func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) { cfg.APIKeys = []string{"sk-test"} cfg.BlockStatus = http.StatusConflict cfg.BlockMessage = "命中历史风险输入" + cfg.AutoBanEnabled = true + cfg.BanThreshold = 1 rawCfg, err := json.Marshal(cfg) require.NoError(t, err) @@ -1114,20 +1384,23 @@ func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) { content.Normalize() hashCache.hashes[content.Hash()] = struct{}{} + repo := &contentModerationTestRepo{} + userRepo := &contentModerationTestUserRepo{user: &User{ID: 1001, Status: StatusActive}} svc := NewContentModerationService( &contentModerationTestSettingRepo{values: map[string]string{ SettingKeyRiskControlEnabled: "true", SettingKeyContentModerationConfig: string(rawCfg), }}, - &contentModerationTestRepo{}, + repo, hashCache, nil, - nil, + userRepo, nil, nil, ) decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: 1001, Protocol: ContentModerationProtocolOpenAIChat, Body: []byte(`{"messages":[{"role":"user","content":"blocked prompt"}]}`), }) @@ -1138,7 +1411,73 @@ func TestContentModerationCheck_PreHashUsesRedisHashCache(t *testing.T) { require.Equal(t, content.Hash(), decision.InputHash) require.Contains(t, decision.Message, "命中历史风险输入") require.Contains(t, decision.Message, content.Hash()) - require.Len(t, hashCache.checked, 1) + require.Len(t, hashCache.snapshotChecked(), 1) + logs := requireContentModerationLogCount(t, repo, 1) + require.True(t, logs[0].Flagged) + require.Equal(t, ContentModerationActionHashBlock, logs[0].Action) + require.Equal(t, 1.0, logs[0].CategoryScores["hash"]) + require.Equal(t, ContentModerationModePreBlock, logs[0].Mode) + require.Zero(t, logs[0].ViolationCount) + require.False(t, logs[0].AutoBanned) + require.Empty(t, userRepo.updated) +} + +func TestContentModerationCheck_HashBlockLogsDoNotIncreaseNextViolationCount(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(moderationAPIResponse{ + Results: []moderationAPIResult{{ + CategoryScores: map[string]float64{"sexual": 0.9}, + }}, + }) + })) + defer server.Close() + + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BaseURL = server.URL + cfg.APIKeys = []string{"sk-test"} + cfg.AutoBanEnabled = false + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + + userID := int64(1001) + repo := &contentModerationTestRepo{} + hashLog := &ContentModerationLog{ + UserID: &userID, + Action: ContentModerationActionHashBlock, + Flagged: true, + HighestCategory: "hash", + HighestScore: 1, + CreatedAt: time.Now(), + } + require.NoError(t, repo.CreateLog(context.Background(), hashLog)) + + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + UserID: userID, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"new blocked prompt"}]}`), + }) + + require.NoError(t, err) + require.True(t, decision.Blocked) + logs := requireContentModerationLogCount(t, repo, 2) + require.Equal(t, ContentModerationActionHashBlock, logs[0].Action) + require.Equal(t, ContentModerationActionBlock, logs[1].Action) + require.Equal(t, 1, logs[1].ViolationCount) } func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T) { @@ -1188,8 +1527,8 @@ func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T require.True(t, decision.Blocked) require.Equal(t, ContentModerationActionBlock, decision.Action) require.Equal(t, 1, requestCount) - require.Len(t, hashCache.recorded, 1) - require.Len(t, repo.logs, 1) + recorded := requireRecordedHashCount(t, hashCache, 1) + requireContentModerationLogCount(t, repo, 1) decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ Protocol: ContentModerationProtocolOpenAIChat, @@ -1198,9 +1537,11 @@ func TestContentModerationCheck_PreBlockFlaggedWritesRedisHashCache(t *testing.T require.NoError(t, err) require.True(t, decision.Blocked) require.Equal(t, ContentModerationActionHashBlock, decision.Action) - require.Equal(t, hashCache.recorded[0], decision.InputHash) + require.Equal(t, recorded[0], decision.InputHash) require.Equal(t, 1, requestCount) - require.Len(t, repo.logs, 1) + logs := requireContentModerationLogCount(t, repo, 2) + require.Equal(t, ContentModerationActionBlock, logs[0].Action) + require.Equal(t, ContentModerationActionHashBlock, logs[1].Action) } func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing.T) { @@ -1215,8 +1556,8 @@ func TestContentModerationDeleteFlaggedInputHash_NormalizesAndDeletes(t *testing require.NoError(t, err) require.Equal(t, existingHash, result.InputHash) require.True(t, result.Deleted) - require.NotContains(t, hashCache.hashes, existingHash) - require.Equal(t, []string{existingHash}, hashCache.deleted) + require.False(t, hashCache.hasHash(existingHash)) + require.Equal(t, []string{existingHash}, hashCache.snapshotDeleted()) result, err = svc.DeleteFlaggedInputHash(context.Background(), existingHash) @@ -1296,8 +1637,8 @@ func TestContentModerationCheck_AsyncFlaggedWritesRedisHashCache(t *testing.T) { }, cfg, ContentModerationInput{Text: "bad prompt"}, strings.Repeat("b", 64), contentModerationIntPtr(25), false) require.False(t, decision.Blocked) - require.Len(t, hashCache.recorded, 1) - require.Len(t, repo.logs, 1) + requireRecordedHashCount(t, hashCache, 1) + requireContentModerationLogCount(t, repo, 1) } func TestBuildContentModerationAccountDisabledEmailBody_ContainsBanDetails(t *testing.T) { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2b22e94d..2e33322a 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -1,6 +1,10 @@ package service -import "github.com/Wei-Shaw/sub2api/internal/domain" +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/domain" +) // Status constants const ( @@ -40,6 +44,26 @@ const ( PlatformWindsurf = domain.PlatformWindsurf ) +// AllowedQuotaPlatforms 是允许设置 user × platform quota 的平台列表(单一权威来源)。 +// ent/schema/user_platform_quota.go 的 Validate 函数独立维护(构建期约束), +// 若新增平台需同步修改该 schema。 +var AllowedQuotaPlatforms = []string{ + PlatformAnthropic, + PlatformOpenAI, + PlatformGemini, + PlatformAntigravity, +} + +// IsAllowedQuotaPlatform 报告 s 是否为合法的 quota platform 标识。 +func IsAllowedQuotaPlatform(s string) bool { + for _, p := range AllowedQuotaPlatforms { + if p == s { + return true + } + } + return false +} + // Account type constants const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) @@ -425,5 +449,15 @@ const ( SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置 ) +// SettingKeyDefaultPlatformQuotas —— 系统全局:每用户 × 平台日/周/月 USD 上限(JSON)。 +// 值为 map[platform]{daily,weekly,monthly},null/缺省 = 不限制;0 = 禁用;>0 = USD 上限。 +const SettingKeyDefaultPlatformQuotas = "default_platform_quotas" + +// SettingKeyAuthSourcePlatformQuotas 返回某 auth source 的 platform quota JSON key。 +// 形如 auth_source_default_{source}_platform_quotas +func SettingKeyAuthSourcePlatformQuotas(source string) string { + return fmt.Sprintf("auth_source_default_%s_platform_quotas", source) +} + // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). const AdminAPIKeyPrefix = "admin-" diff --git a/backend/internal/service/domain_constants_test.go b/backend/internal/service/domain_constants_test.go new file mode 100644 index 00000000..2157c979 --- /dev/null +++ b/backend/internal/service/domain_constants_test.go @@ -0,0 +1,23 @@ +//go:build unit + +package service + +import "testing" + +// TestSettingKeyDefaultPlatformQuotas 验证新的系统层 JSON key 常量值正确。 +func TestSettingKeyDefaultPlatformQuotas(t *testing.T) { + if SettingKeyDefaultPlatformQuotas != "default_platform_quotas" { + t.Errorf("SettingKeyDefaultPlatformQuotas = %q, want %q", + SettingKeyDefaultPlatformQuotas, "default_platform_quotas") + } +} + +// TestSettingKeyAuthSourcePlatformQuotas 验证新的 auth-source JSON key 函数返回值正确。 +func TestSettingKeyAuthSourcePlatformQuotas(t *testing.T) { + if got := SettingKeyAuthSourcePlatformQuotas("email"); got != "auth_source_default_email_platform_quotas" { + t.Fatalf("got %q, want %q", got, "auth_source_default_email_platform_quotas") + } + if got := SettingKeyAuthSourcePlatformQuotas("dingtalk"); got != "auth_source_default_dingtalk_platform_quotas" { + t.Fatalf("got %q, want %q", got, "auth_source_default_dingtalk_platform_quotas") + } +} diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index 7ac77f77..eaf67fab 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -166,7 +166,7 @@ func (s *GatewayService) ForwardAsChatCompletions( Message: upstreamMsg, }) if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel) } return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 8f8a1e94..c55a5a98 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -163,7 +163,7 @@ func (s *GatewayService) ForwardAsResponses( Message: upstreamMsg, }) if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, mappedModel) } return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index e0f5b0c8..e9f84716 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -159,7 +159,7 @@ func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx con func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { +func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { return nil } func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 1c3ace93..f8d98e55 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -44,7 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, - nil, + nil, // userPlatformQuotaRepo ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0e6ce24d..149e475b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -98,6 +98,16 @@ var ( modelsListCacheHitTotal atomic.Int64 modelsListCacheMissTotal atomic.Int64 modelsListCacheStoreTotal atomic.Int64 + + // userPlatformQuotaDBIncrErrorTotal 统计 finalizePostUsageBilling 异步 goroutine + // 中 IncrementUsageWithReset 失败次数。Redis 已成功累加 + DB 写失败意味着 + // Redis cache TTL 过期或被清后该笔 cost 会丢失(与实际消费偏差)。 + // oncall 通过 GatewayUserPlatformQuotaIncrStats() 暴露给 ops 面板做阈值告警。 + userPlatformQuotaDBIncrErrorTotal atomic.Int64 + // userPlatformQuotaDBIncrLegacyErrorTotal 统计 legacy postUsageBilling + // (applyUsageBilling 在 repo==nil 时 fallback)路径下的失败次数; + // 与 DB Incr 失败分开计数,便于区分"主路径暂时故障"vs"基础设施长期未配齐"。 + userPlatformQuotaDBIncrLegacyErrorTotal atomic.Int64 ) func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { @@ -120,6 +130,15 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +// GatewayUserPlatformQuotaIncrStats 返回 (mainPathErr, legacyPathErr)。 +// mainPathErr:finalizePostUsageBilling 异步 goroutine 写 DB 失败累计次数; +// legacyPathErr:postUsageBilling fallback 路径写 DB 失败累计次数。 +// ops 监控面板可以按"持续上升斜率"做告警阈值。 +func GatewayUserPlatformQuotaIncrStats() (mainPathErr, legacyPathErr int64) { + return userPlatformQuotaDBIncrErrorTotal.Load(), + userPlatformQuotaDBIncrLegacyErrorTotal.Load() +} + func openAIStreamEventIsTerminal(data string) bool { trimmed := strings.TrimSpace(data) if trimmed == "" { @@ -597,6 +616,7 @@ type GatewayService struct { debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService balanceNotifyService *BalanceNotifyService + userPlatformQuotaRepo UserPlatformQuotaRepository } // NewGatewayService creates a new GatewayService @@ -628,42 +648,44 @@ func NewGatewayService( resolver *ModelPricingResolver, balanceNotifyService *BalanceNotifyService, rpmTokenBucketSvc *RPMTokenBucketService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) svc := &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - usageBillingRepo: usageBillingRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - digestStore: digestStore, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, - rpmCache: rpmCache, - rpmTokenBucket: rpmTokenBucketSvc, - userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), - settingService: settingService, - modelsListCache: gocache.New(modelsListTTL, time.Minute), - modelsListCacheTTL: modelsListTTL, - responseHeaderFilter: compileResponseHeaderFilter(cfg), - tlsFPProfileService: tlsFPProfileService, - channelService: channelService, - resolver: resolver, - balanceNotifyService: balanceNotifyService, + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + rpmTokenBucket: rpmTokenBucketSvc, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + settingService: settingService, + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + tlsFPProfileService: tlsFPProfileService, + channelService: channelService, + resolver: resolver, + balanceNotifyService: balanceNotifyService, + userPlatformQuotaRepo: userPlatformQuotaRepo, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -4922,7 +4944,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } return s.handleRetryExhaustedError(ctx, resp, c, account) @@ -4938,7 +4960,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(ctx, resp, account, reqModel) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4956,7 +4978,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } if resp.StatusCode >= 400 { @@ -4965,7 +4987,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if readErr != nil { // ReadAll failed, fall back to normal error handling without consuming the stream - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, reqModel) } _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) @@ -5001,11 +5023,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } else { logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) } - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(ctx, resp, account, reqModel) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } } - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, reqModel) } // 处理正常响应 @@ -5236,7 +5258,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } return s.handleRetryExhaustedError(ctx, resp, c, account) @@ -5250,7 +5272,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(ctx, resp, account, input.RequestModel) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -5270,12 +5292,12 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } if resp.StatusCode >= 400 { - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, input.RequestModel) } var usage *ClaudeUsage @@ -6009,7 +6031,7 @@ func (s *GatewayService) handleBedrockUpstreamErrors( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } return s.handleRetryExhaustedError(ctx, resp, c, account) @@ -6033,7 +6055,7 @@ func (s *GatewayService) handleBedrockUpstreamErrors( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } @@ -7057,7 +7079,7 @@ func isCountTokensUnsupported404(statusCode int, body []byte) bool { return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") } -func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { +func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, requestedModel ...string) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 调试日志:打印上游错误响应 @@ -7104,7 +7126,11 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res // 处理上游错误,标记账号状态 shouldDisable := false if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + if len(requestedModel) > 0 { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0]) + } else { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } } if shouldDisable { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} @@ -7220,8 +7246,12 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re } } -func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { +func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if len(requestedModel) > 0 { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0]) + return + } s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } @@ -8069,6 +8099,7 @@ type RecordUsageInput struct { RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform) ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } @@ -8098,6 +8129,31 @@ type postUsageBillingParams struct { IsSubscriptionBill bool AccountRateMultiplier float64 APIKeyService APIKeyQuotaUpdater + Platform string // 来自 APIKey 关联 Group 的平台标识 +} + +// PlatformFromAPIKey 从 APIKey 关联的 Group 推导 platform 名称。 +// apiKey 为 nil 或 Group 信息缺失时返回空串(调用方据此 short-circuit quota 累加)。 +// 导出供 handler 层调用。 +func PlatformFromAPIKey(apiKey *APIKey) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return apiKey.Group.Platform +} + +// QuotaPlatform 返回 user×platform 配额计量使用的平台标识。 +// 强制平台路由(如 /antigravity)优先按 ctx 中的 ForcePlatform 计量,否则回退到 +// APIKey 关联 Group 的平台。 +// +// 注意:必须用带 ForcePlatform 的请求 context 调用(如 handler 的 c.Request.Context())。 +// 后扣运行在 worker 池的 background ctx 上没有 ForcePlatform,因此后扣平台由 handler +// 预先算定、经 RecordUsageInput.QuotaPlatform 传入,不要在后扣链路用 worker ctx 调用本函数。 +func QuotaPlatform(ctx context.Context, apiKey *APIKey) string { + if fp, ok := ctx.Value(ctxkey.ForcePlatform).(string); ok && fp != "" { + return fp + } + return PlatformFromAPIKey(apiKey) } func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { @@ -8156,6 +8212,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } } + // Platform quota DB-only 累加(与 finalizePostUsageBilling 行为对齐的兜底): + // - 仅对 standard(余额)模式生效;订阅模式豁免 + // - 直接走 DB,不经 Redis Incr 队列:legacy 路径在 repo==nil(仓库未注入) + // 时被触发,此时整套 billing repo 都不可用,没有"双队列"风险 + // - 失败仅记 ALERT log + counter,不阻断主扣费流程;与正常路径一致 + // + // 历史背景:原 legacy path 完全跳过此累加,导致部署中如果 repo 偶然为 nil + // 时用户消费可绕过 platform quota,存在静默资金风险。 + if !p.IsSubscriptionBill && p.Platform != "" && cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil { + if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(billingCtx, p.User.ID, p.Platform, cost.ActualCost, time.Now().UTC()); err != nil { + userPlatformQuotaDBIncrLegacyErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "ALERT: legacy incr user platform quota DB failed user=%d platform=%s cost=%f: %v", p.User.ID, p.Platform, cost.ActualCost, err) + } + } + // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing // cache updates. The legacy path does DB writes directly; the finalize path // does cache queue + notifications. Notifications are dispatched separately @@ -8279,11 +8350,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog } } - finalizePostUsageBilling(p, deps, result) + finalizePostUsageBilling(billingCtx, p, deps, result) return true, nil } -func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { +func finalizePostUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { if p == nil || p.Cost == nil || deps == nil { return } @@ -8302,6 +8373,32 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + // Platform quota 累加:仅在 standard(余额)模式生效;订阅模式豁免 + // Redis 同步写 + DB 异步持久化: + // - Redis 同步:确保下次 preflight 立即看到最新 usage,把 TOCTOU 超支窗口 + // 限制在并发 in-flight 请求数量内(旧实现的异步入队会让超支无限累积直到 worker 处理) + // - DB 异步:在独立 goroutine 中走 detached context,失败用 ALERT log 触发 oncall 对账 + if !p.IsSubscriptionBill && p.Platform != "" && p.Cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil { + deps.billingCacheService.IncrementUserPlatformQuotaUsage(p.User.ID, p.Platform, p.Cost.ActualCost) + dbCtx, dbCancel := detachUpstreamContext(ctx) + userID, platform, cost := p.User.ID, p.Platform, p.Cost.ActualCost + go func() { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "ALERT: panic in user platform quota incr goroutine user=%d platform=%s: %v", userID, platform, r) + } + }() + defer dbCancel() + if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(dbCtx, userID, platform, cost, time.Now().UTC()); err != nil { + // 失败计数器:暴露给 GatewayUserPlatformQuotaIncrStats(),由 ops 面板做斜率告警。 + userPlatformQuotaDBIncrErrorTotal.Add(1) + // ALERT 级别:DB 持久化失败意味着 Redis cache 失效后该笔 cost 永久丢失, + // 用户配额视图与实际消费会偏差,oncall 需要据此对账或人工补录。 + logger.LegacyPrintf("service.gateway", "ALERT: incr user platform quota DB failed user=%d platform=%s cost=%f: %v", userID, platform, cost, err) + } + }() + } + // Notification checks run async — all parameters are already captured, // no dependency on the request context or upstream connection. go notifyBalanceLow(p, deps, result) @@ -8407,22 +8504,24 @@ func detachUpstreamContext(ctx context.Context) (context.Context, context.Cancel // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) type billingDeps struct { - accountRepo AccountRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - billingCacheService *BillingCacheService - deferredService *DeferredService - balanceNotifyService *BalanceNotifyService + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService + balanceNotifyService *BalanceNotifyService + userPlatformQuotaRepo UserPlatformQuotaRepository } func (s *GatewayService) billingDeps() *billingDeps { return &billingDeps{ - accountRepo: s.accountRepo, - userRepo: s.userRepo, - userSubRepo: s.userSubRepo, - billingCacheService: s.billingCacheService, - deferredService: s.deferredService, - balanceNotifyService: s.balanceNotifyService, + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + balanceNotifyService: s.balanceNotifyService, + userPlatformQuotaRepo: s.userPlatformQuotaRepo, } } @@ -8480,6 +8579,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu RequestPayloadHash: input.RequestPayloadHash, ForceCacheBilling: input.ForceCacheBilling, APIKeyService: input.APIKeyService, + QuotaPlatform: input.QuotaPlatform, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ EnableClaudePath: true, @@ -8502,6 +8602,7 @@ type RecordUsageLongContextInput struct { LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform) ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } @@ -8521,6 +8622,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * RequestPayloadHash: input.RequestPayloadHash, ForceCacheBilling: input.ForceCacheBilling, APIKeyService: input.APIKeyService, + QuotaPlatform: input.QuotaPlatform, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ LongContextThreshold: input.LongContextThreshold, @@ -8542,6 +8644,7 @@ type recordUsageCoreInput struct { RequestPayloadHash string ForceCacheBilling bool APIKeyService APIKeyQuotaUpdater + QuotaPlatform string ChannelUsageFields } @@ -8639,6 +8742,13 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage return nil } + // 配额平台由 handler 在请求 ctx 内经 QuotaPlatform() 算定并通过 input 传入; + // 后扣运行在 worker 池的 background ctx 上,无法再从 ctx 取 ForcePlatform。 + // 缺省(未设置)时回退到分组平台,保持对其它调用方的兼容。 + quotaPlatform := input.QuotaPlatform + if quotaPlatform == "" { + quotaPlatform = PlatformFromAPIKey(apiKey) + } requestID := usageLog.RequestID _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, @@ -8650,6 +8760,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, + Platform: quotaPlatform, }, s.billingDeps(), s.usageBillingRepo) if billingErr != nil { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 815ccb51..8b96c0e8 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -150,7 +150,7 @@ func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx conte func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { +func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { return nil } func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index f6155352..9aa2a52f 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -8,6 +8,7 @@ import ( ) type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig +type GroupModelsListConfig = domain.GroupModelsListConfig type Group struct { ID int64 @@ -61,6 +62,7 @@ type Group struct { RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) DefaultMappedModel string MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + ModelsListConfig GroupModelsListConfig // RPMLimit 分组级每分钟请求数上限(0 = 不限制)。 // 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。 diff --git a/backend/internal/service/group_models_list.go b/backend/internal/service/group_models_list.go new file mode 100644 index 00000000..b10de724 --- /dev/null +++ b/backend/internal/service/group_models_list.go @@ -0,0 +1,32 @@ +package service + +import "strings" + +func normalizeGroupModelsListConfig(cfg GroupModelsListConfig) GroupModelsListConfig { + out := GroupModelsListConfig{Enabled: cfg.Enabled} + if len(cfg.Models) == 0 { + return out + } + + seen := make(map[string]struct{}, len(cfg.Models)) + out.Models = make([]string, 0, len(cfg.Models)) + for _, model := range cfg.Models { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + out.Models = append(out.Models, model) + } + if len(out.Models) == 0 { + out.Models = nil + } + return out +} + +func (g *Group) CustomModelsListEnabled() bool { + return g != nil && g.ModelsListConfig.Enabled && len(g.ModelsListConfig.Models) > 0 +} diff --git a/backend/internal/service/http_upstream_profile.go b/backend/internal/service/http_upstream_profile.go new file mode 100644 index 00000000..2d63bbd5 --- /dev/null +++ b/backend/internal/service/http_upstream_profile.go @@ -0,0 +1,42 @@ +package service + +import "context" + +// HTTPUpstreamProfile marks HTTP upstream requests that need provider-specific +// transport policy. +type HTTPUpstreamProfile string + +const ( + HTTPUpstreamProfileDefault HTTPUpstreamProfile = "" + HTTPUpstreamProfileOpenAI HTTPUpstreamProfile = "openai" +) + +type httpUpstreamProfileContextKey struct{} + +// WithHTTPUpstreamProfile injects an upstream transport profile into ctx. +func WithHTTPUpstreamProfile(ctx context.Context, profile HTTPUpstreamProfile) context.Context { + if ctx == nil { + ctx = context.Background() + } + if profile == HTTPUpstreamProfileDefault { + return ctx + } + return context.WithValue(ctx, httpUpstreamProfileContextKey{}, profile) +} + +// HTTPUpstreamProfileFromContext resolves the upstream transport profile from ctx. +func HTTPUpstreamProfileFromContext(ctx context.Context) HTTPUpstreamProfile { + if ctx == nil { + return HTTPUpstreamProfileDefault + } + profile, ok := ctx.Value(httpUpstreamProfileContextKey{}).(HTTPUpstreamProfile) + if !ok { + return HTTPUpstreamProfileDefault + } + switch profile { + case HTTPUpstreamProfileOpenAI: + return profile + default: + return HTTPUpstreamProfileDefault + } +} diff --git a/backend/internal/service/http_upstream_profile_test.go b/backend/internal/service/http_upstream_profile_test.go new file mode 100644 index 00000000..96f0cd31 --- /dev/null +++ b/backend/internal/service/http_upstream_profile_test.go @@ -0,0 +1,21 @@ +package service + +import ( + "context" + "testing" +) + +func TestWithHTTPUpstreamProfile_DefaultKeepsContext(t *testing.T) { + ctx := context.Background() + got := WithHTTPUpstreamProfile(ctx, HTTPUpstreamProfileDefault) + if got != ctx { + t.Fatal("default profile should not wrap context") + } +} + +func TestWithHTTPUpstreamProfile_OpenAI(t *testing.T) { + ctx := WithHTTPUpstreamProfile(context.TODO(), HTTPUpstreamProfileOpenAI) + if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileOpenAI { + t.Fatalf("expected profile %q, got %q", HTTPUpstreamProfileOpenAI, profile) + } +} diff --git a/backend/internal/service/model_not_found_error.go b/backend/internal/service/model_not_found_error.go new file mode 100644 index 00000000..910a97d8 --- /dev/null +++ b/backend/internal/service/model_not_found_error.go @@ -0,0 +1,44 @@ +package service + +import ( + "net/http" + "strings" +) + +var upstreamModelNotFoundKeywords = []string{"model not found", "unknown model", "not found"} + +func isUpstreamModelNotFoundError(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + normalized := normalizeModelNotFoundBody(body) + if normalized == "" || !strings.Contains(normalized, "model") { + return false + } + return containsModelNotFoundKeyword(normalized) +} + +func isModelNotFoundError(statusCode int, body []byte) bool { + return isUpstreamModelNotFoundError(statusCode, body) || statusCode == http.StatusNotFound +} + +func containsModelNotFoundKeyword(normalizedBody string) bool { + if normalizedBody == "" { + return false + } + for _, keyword := range upstreamModelNotFoundKeywords { + if strings.Contains(normalizedBody, keyword) { + return true + } + } + return false +} + +func normalizeModelNotFoundBody(body []byte) string { + if len(body) == 0 { + return "" + } + normalized := strings.ToLower(string(body)) + normalized = strings.NewReplacer("_", " ", "-", " ", "\n", " ", "\r", " ", "\t", " ").Replace(normalized) + return strings.Join(strings.Fields(normalized), " ") +} diff --git a/backend/internal/service/model_not_found_error_test.go b/backend/internal/service/model_not_found_error_test.go new file mode 100644 index 00000000..a87340eb --- /dev/null +++ b/backend/internal/service/model_not_found_error_test.go @@ -0,0 +1,66 @@ +package service + +import ( + "net/http" + "testing" +) + +func TestIsUpstreamModelNotFoundError(t *testing.T) { + tests := []struct { + name string + statusCode int + body []byte + want bool + }{ + { + name: "404 model not found message", + statusCode: http.StatusNotFound, + body: []byte(`{"error":{"message":"model not found"}}`), + want: true, + }, + { + name: "404 model_not_found code", + statusCode: http.StatusNotFound, + body: []byte(`{"error":{"code":"model_not_found","message":"The requested model was not found"}}`), + want: true, + }, + { + name: "404 unknown model message", + statusCode: http.StatusNotFound, + body: []byte(`{"error":{"message":"unknown model gpt-5.4"}}`), + want: true, + }, + { + name: "404 endpoint not found is not model specific", + statusCode: http.StatusNotFound, + body: []byte(`{"error":{"message":"endpoint not found"}}`), + want: false, + }, + { + name: "404 arbitrary body is not model specific", + statusCode: http.StatusNotFound, + body: []byte(`404 page not found`), + want: false, + }, + { + name: "non 404 does not match", + statusCode: http.StatusBadRequest, + body: []byte(`{"error":{"message":"model not found"}}`), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isUpstreamModelNotFoundError(tt.statusCode, tt.body); got != tt.want { + t.Fatalf("isUpstreamModelNotFoundError() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAntigravityModelNotFoundKeepsBare404Fallback(t *testing.T) { + if !isModelNotFoundError(http.StatusNotFound, []byte(`endpoint not found`)) { + t.Fatal("antigravity model-not-found helper should keep bare 404 fallback") + } +} diff --git a/backend/internal/service/openai_account_runtime_block_fastpath.go b/backend/internal/service/openai_account_runtime_block_fastpath.go index 41a309fd..a4e905e8 100644 --- a/backend/internal/service/openai_account_runtime_block_fastpath.go +++ b/backend/internal/service/openai_account_runtime_block_fastpath.go @@ -31,7 +31,7 @@ func isOpenAIAccount(account *Account) bool { return account != nil && account.Platform == PlatformOpenAI } -func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool { +func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) bool { stateCtx, cancel := openAIAccountStateContext(ctx) defer cancel() @@ -41,6 +41,9 @@ func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Cont if s == nil || account == nil || s.rateLimitService == nil { return false } + if len(requestedModel) > 0 && s.rateLimitService.HandleUpstreamModelNotFound(stateCtx, account, requestedModel[0], statusCode, responseBody) { + return true + } shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody) if shouldDisable { s.BlockAccountScheduling(account, time.Time{}, "upstream_disable") diff --git a/backend/internal/service/openai_account_runtime_block_fastpath_test.go b/backend/internal/service/openai_account_runtime_block_fastpath_test.go index 95336e81..3784dd33 100644 --- a/backend/internal/service/openai_account_runtime_block_fastpath_test.go +++ b/backend/internal/service/openai_account_runtime_block_fastpath_test.go @@ -57,6 +57,28 @@ func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account)) } +func TestOpenAIModelNotFound_DoesNotRuntimeBlockWholeAccount(t *testing.T) { + repo := &modelNotFoundAccountRepoStub{} + svc := &OpenAIGatewayService{ + rateLimitService: &RateLimitService{accountRepo: repo}, + } + account := openAIModelNotFoundTempAccount() + + shouldDisable := svc.handleOpenAIAccountUpstreamError( + context.Background(), + account, + http.StatusNotFound, + http.Header{}, + []byte(`{"error":{"code":"model_not_found","message":"model not found"}}`), + "gpt-5.4", + ) + + require.True(t, shouldDisable) + require.False(t, svc.isOpenAIAccountRuntimeBlocked(account)) + require.Zero(t, repo.tempCalls) + require.Len(t, repo.modelRateLimitCalls, 1) +} + func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) { svc := &OpenAIGatewayService{} account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 50acad24..70d5e256 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -545,6 +545,50 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa require.Equal(t, int64(32002), account.ID) } +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_ModelRateLimitOnlySkipsThatModel(t *testing.T) { + ctx := context.Background() + resetAt := time.Now().Add(30 * time.Minute).Format(time.RFC3339) + primary := Account{ + ID: 32101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "gpt-5.4": map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + secondary := Account{ + ID: 32102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + } + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{primary, secondary}}, + cfg: &config.Config{}, + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.4", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(32102), account.ID) + + account, err = svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gpt-5.3", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(32101), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) { ctx := context.Background() groupID := int64(10103) diff --git a/backend/internal/service/openai_apikey_responses_probe.go b/backend/internal/service/openai_apikey_responses_probe.go index a4eb9252..051527f3 100644 --- a/backend/internal/service/openai_apikey_responses_probe.go +++ b/backend/internal/service/openai_apikey_responses_probe.go @@ -95,6 +95,7 @@ func (s *AccountTestService) ProbeOpenAIAPIKeyResponsesSupport(ctx context.Conte logger.LegacyPrintf("service.openai_probe", "probe_build_request_failed: account_id=%d err=%v", accountID, err) return } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Accept", "application/json") diff --git a/backend/internal/service/openai_embeddings.go b/backend/internal/service/openai_embeddings.go new file mode 100644 index 00000000..359df3bb --- /dev/null +++ b/backend/internal/service/openai_embeddings.go @@ -0,0 +1,240 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +func (s *OpenAIGatewayService) ForwardEmbeddings( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + originalModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) + if originalModel == "" { + writeOpenAIEmbeddingsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return nil, fmt.Errorf("missing model in request") + } + + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + upstreamBody := body + if upstreamModel != originalModel { + upstreamBody = ReplaceModelInBody(body, upstreamModel) + } + + logger.L().Debug("openai embeddings: forwarding", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + ) + + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + targetURL := buildOpenAIEmbeddingsURL(validatedURL) + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + upstreamReq.Header.Set("Accept", "application/json") + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiCCRawAllowedHeaders[lowerKey] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), + } + } + writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter) + return nil, fmt.Errorf("upstream returned status %d", resp.StatusCode) + } + + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeOpenAIEmbeddingsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + writeOpenAIEmbeddingsUpstreamResponse(c, resp, respBody, s.responseHeaderFilter) + + return &OpenAIForwardResult{ + RequestID: firstNonEmptyString(resp.Header.Get("x-request-id"), resp.Header.Get("request-id")), + Usage: extractOpenAIEmbeddingsUsage(respBody), + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +func writeOpenAIEmbeddingsUpstreamResponse(c *gin.Context, resp *http.Response, body []byte, filter *responseheaders.CompiledHeaderFilter) { + if c == nil || resp == nil { + return + } + if c.Writer.Written() { + return + } + if resp.Header != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter) + } + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Writer.Header().Set("Content-Type", ct) + } else { + c.Writer.Header().Set("Content-Type", "application/json") + } + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(body) +} + +func writeOpenAIEmbeddingsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func extractOpenAIEmbeddingsUsage(body []byte) OpenAIUsage { + usage := gjson.GetBytes(body, "usage") + if !usage.Exists() || !usage.IsObject() { + return OpenAIUsage{} + } + inputTokens := firstPositiveGJSONInt( + usage.Get("prompt_tokens"), + usage.Get("input_tokens"), + usage.Get("total_tokens"), + ) + outputTokens := firstPositiveGJSONInt( + usage.Get("completion_tokens"), + usage.Get("output_tokens"), + ) + cacheReadTokens := firstPositiveGJSONInt( + usage.Get("prompt_tokens_details.cached_tokens"), + usage.Get("input_tokens_details.cached_tokens"), + usage.Get("cache_read_tokens"), + usage.Get("cache_read_input_tokens"), + ) + cacheCreationTokens := firstPositiveGJSONInt( + usage.Get("cache_creation_tokens"), + usage.Get("cache_creation_input_tokens"), + usage.Get("input_tokens_details.cache_creation_tokens"), + ) + return OpenAIUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cacheReadTokens, + CacheCreationInputTokens: cacheCreationTokens, + } +} + +func firstPositiveGJSONInt(values ...gjson.Result) int { + for _, value := range values { + if !value.Exists() { + continue + } + n := int(value.Int()) + if n > 0 { + return n + } + } + return 0 +} + +func buildOpenAIEmbeddingsURL(base string) string { + return buildOpenAIEndpointURL(base, "/v1/embeddings") +} diff --git a/backend/internal/service/openai_embeddings_test.go b/backend/internal/service/openai_embeddings_test.go new file mode 100644 index 00000000..c7e89d64 --- /dev/null +++ b/backend/internal/service/openai_embeddings_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildOpenAIEmbeddingsURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + want string + }{ + {"bare domain", "https://api.openai.com", "https://api.openai.com/v1/embeddings"}, + {"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/embeddings"}, + {"already embeddings", "https://api.openai.com/v1/embeddings", "https://api.openai.com/v1/embeddings"}, + {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/embeddings"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, buildOpenAIEmbeddingsURL(tt.base)) + }) + } +} + +func TestForwardEmbeddings_APIKeyPassthroughRecordsUsageAndBatchInput(t *testing.T) { + gin.SetMode(gin.TestMode) + + reqBody := []byte(`{ + "model":"nowledge-embedding", + "input":["hello","world"], + "encoding_format":"float", + "dimensions":256 + }`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewReader(reqBody)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-Request-Id": []string{"emb-rid"}, + }, + Body: io.NopCloser(strings.NewReader(`{ + "object":"list", + "data":[ + {"object":"embedding","index":0,"embedding":[0.1,0.2]}, + {"object":"embedding","index":1,"embedding":[0.3,0.4]} + ], + "model":"jina-embeddings-v5-text-small", + "usage":{"prompt_tokens":13,"total_tokens":13} + }`)), + }} + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + } + account := &Account{ + ID: 42, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.jina.ai", + "model_mapping": map[string]any{ + "nowledge-embedding": "jina-embeddings-v5-text-small", + }, + }, + } + + result, err := svc.ForwardEmbeddings(context.Background(), c, account, reqBody, "") + + require.NoError(t, err) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, result) + require.Equal(t, "emb-rid", result.RequestID) + require.Equal(t, "nowledge-embedding", result.Model) + require.Equal(t, "jina-embeddings-v5-text-small", result.BillingModel) + require.Equal(t, "jina-embeddings-v5-text-small", result.UpstreamModel) + require.Equal(t, 13, result.Usage.InputTokens) + require.Equal(t, 0, result.Usage.OutputTokens) + require.Equal(t, "https://api.jina.ai/v1/embeddings", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "jina-embeddings-v5-text-small", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, int64(2), gjson.GetBytes(upstream.lastBody, "input.#").Int()) + require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "input.0").String()) + require.Equal(t, "world", gjson.GetBytes(upstream.lastBody, "input.1").String()) + require.Equal(t, "float", gjson.GetBytes(upstream.lastBody, "encoding_format").String()) + require.Equal(t, int64(256), gjson.GetBytes(upstream.lastBody, "dimensions").Int()) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 27eb211e..807ff43a 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -76,7 +76,6 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } originalModel := chatReq.Model clientStream := chatReq.Stream - includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. @@ -193,6 +192,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( if policyErr != nil { var blocked *OpenAIFastBlockedError if errors.As(policyErr, &blocked) { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) } return nil, policyErr @@ -276,21 +276,21 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } - return s.handleChatCompletionsErrorResponse(resp, c, account) + return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel) } // 9. Handle normal response var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, account, originalModel, billingModel, upstreamModel, includeUsage, startTime, len(body)) + result, handleErr = s.handleChatStreamingResponse(resp, c, account, originalModel, billingModel, upstreamModel, startTime, len(body)) } else { result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } @@ -358,8 +358,9 @@ func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( resp *http.Response, c *gin.Context, account *Account, + requestedModel ...string, ) (*OpenAIForwardResult, error) { - return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) + return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError, requestedModel...) } // handleChatBufferedStreamingResponse reads all Responses SSE events from the @@ -416,7 +417,6 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( originalModel string, billingModel string, upstreamModel string, - includeUsage bool, startTime time.Time, requestBodyLen int, ) (*OpenAIForwardResult, error) { @@ -440,7 +440,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( state := apicompat.NewResponsesEventToChatState() state.Model = originalModel - state.IncludeUsage = includeUsage + // 网关作为计费链路的一环,不能把下游 usage 输出绑定到客户端是否显式请求。 + // raw Chat Completions 直转路径已经强制透出 usage,这里保持同样行为,避免级联代理计费为 0。 + state.IncludeUsage = true var usage OpenAIUsage var firstTokenMs *int @@ -501,10 +503,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } refusalDetector.ObservePayload([]byte(payload)) - // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) - if isTerminalEvent && event.Response != nil && event.Response.Usage != nil { - usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + if isTerminalEvent { + if event.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Usage) + } + if event.Response != nil && event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } } chunks := apicompat.ResponsesEventToChatChunks(&event, state) diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index ad6d3e8d..e351fa75 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -93,6 +93,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( if policyErr != nil { var blocked *OpenAIFastBlockedError if errors.As(policyErr, &blocked) { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) } return nil, policyErr @@ -135,6 +136,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } + upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) upstreamReq.Header.Set("Content-Type", "application/json") upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) if clientStream { @@ -206,14 +208,14 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } - return s.handleChatCompletionsErrorResponse(resp, c, account) + return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel) } // 8. Forward response diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go index 64449636..eeb814f1 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go @@ -116,6 +116,7 @@ func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDown require.Equal(t, 3, result.Usage.CacheReadInputTokens) require.NotNil(t, upstream.lastReq) require.NoError(t, upstream.lastReq.Context().Err()) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) require.Contains(t, rec.Body.String(), `"usage"`) require.Contains(t, rec.Body.String(), "data: [DONE]") @@ -327,7 +328,6 @@ func TestHandleChatStreamingResponse_SilentRefusalReasoningSummaryExempt(t *test "gpt-5.5", "gpt-5.5", "gpt-5.5", - false, time.Now(), openAISilentRefusalMinRequestBodyBytes, ) diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index a26091a3..9a5ea711 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -180,6 +180,158 @@ func TestForwardAsChatCompletions_ClientDisconnectDrainsUpstreamUsage(t *testing require.Equal(t, 4, result.Usage.CacheReadInputTokens) } +func TestForwardAsChatCompletions_StreamsUsageWithoutClientStreamOptions(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":13,"output_tokens":7,"total_tokens":20,"input_tokens_details":{"cached_tokens":5}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_usage_no_stream_options"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 13, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.CacheReadInputTokens) + + responseBody := rec.Body.String() + require.Contains(t, responseBody, `"usage"`) + require.Contains(t, responseBody, `"prompt_tokens":13`) + require.Contains(t, responseBody, `"completion_tokens":7`) + require.Contains(t, responseBody, `"cached_tokens":5`) +} + +func TestForwardAsChatCompletions_StreamsTopLevelTerminalUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_top","model":"gpt-5.4","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_top","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}]},"usage":{"input_tokens":21,"output_tokens":9,"total_tokens":30,"input_tokens_details":{"cached_tokens":4}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_top_level_usage"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 21, result.Usage.InputTokens) + require.Equal(t, 9, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) + + responseBody := rec.Body.String() + require.Contains(t, responseBody, `"usage"`) + require.Contains(t, responseBody, `"prompt_tokens":21`) + require.Contains(t, responseBody, `"completion_tokens":9`) + require.Contains(t, responseBody, `"cached_tokens":4`) +} + +func TestForwardAsChatCompletions_BufferedTopLevelTerminalUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_top_buffered","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}]},"usage":{"input_tokens":18,"output_tokens":6,"total_tokens":24,"input_tokens_details":{"cached_tokens":3}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_buffered_top_level_usage"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 18, result.Usage.InputTokens) + require.Equal(t, 6, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + + responseBody := rec.Body.String() + require.Contains(t, responseBody, `"usage"`) + require.Contains(t, responseBody, `"prompt_tokens":18`) + require.Contains(t, responseBody, `"completion_tokens":6`) + require.Contains(t, responseBody, `"cached_tokens":3`) +} + func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 336a7d79..291c217e 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -231,6 +231,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( if policyErr != nil { var blocked *OpenAIFastBlockedError if errors.As(policyErr, &blocked) { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message) } return nil, policyErr @@ -337,15 +338,15 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( Message: upstreamMsg, Detail: upstreamDetail, }) - s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } // Non-failover error: return Anthropic-formatted error to client - return s.handleAnthropicErrorResponse(resp, c, account) + return s.handleAnthropicErrorResponse(resp, c, account, billingModel) } if account.Type == AccountTypeOAuth && promptCacheKey != "" { @@ -412,8 +413,9 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse( resp *http.Response, c *gin.Context, account *Account, + requestedModel ...string, ) (*OpenAIForwardResult, error) { - return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) + return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError, requestedModel...) } // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from @@ -569,6 +571,12 @@ func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( if err := json.Unmarshal([]byte(payload), &event); err == nil { acc.ProcessEvent(&event) if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Usage) + if event.Response.Usage == nil { + event.Response.Usage = event.Usage + } + } if event.Response.Usage != nil { usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } @@ -610,6 +618,12 @@ func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal( acc.ProcessEvent(&event) if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil { + if event.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Usage) + if event.Response.Usage == nil { + event.Response.Usage = event.Usage + } + } if event.Response.Usage != nil { usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) } @@ -712,14 +726,18 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( return false } - // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) - if isTerminalEvent && event.Response != nil { - if id := strings.TrimSpace(event.Response.ID); id != "" { - responseID = id + if isTerminalEvent { + if event.Response != nil { + if id := strings.TrimSpace(event.Response.ID); id != "" { + responseID = id + } + if event.Response.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + } } - if event.Response.Usage != nil { - usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage) + if event.Usage != nil { + usage = copyOpenAIUsageFromResponsesUsage(event.Usage) } } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 096f5b10..9769a82e 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -155,6 +155,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, nil, nil, + nil, // userPlatformQuotaRepo ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback.go b/backend/internal/service/openai_gateway_responses_chat_fallback.go index c3ebc35c..cfab389a 100644 --- a/backend/internal/service/openai_gateway_responses_chat_fallback.go +++ b/backend/internal/service/openai_gateway_responses_chat_fallback.go @@ -116,6 +116,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } + upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) upstreamReq.Header.Set("Content-Type", "application/json") upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) if clientStream { @@ -187,14 +188,14 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } - return s.handleErrorResponse(ctx, resp, c, account, chatBody) + return s.handleErrorResponse(ctx, resp, c, account, chatBody, billingModel) } if clientStream { diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback_test.go b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go index 78df2202..abb645e8 100644 --- a/backend/internal/service/openai_gateway_responses_chat_fallback_test.go +++ b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go @@ -42,6 +42,7 @@ func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletion require.NoError(t, err) require.NotNil(t, result) require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String()) require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String()) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f312f50d..f93cc221 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -343,6 +343,7 @@ type OpenAIGatewayService struct { channelService *ChannelService balanceNotifyService *BalanceNotifyService settingService *SettingService + userPlatformQuotaRepo UserPlatformQuotaRepository openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -387,6 +388,7 @@ func NewOpenAIGatewayService( channelService *ChannelService, balanceNotifyService *BalanceNotifyService, settingService *SettingService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -418,6 +420,7 @@ func NewOpenAIGatewayService( channelService: channelService, balanceNotifyService: balanceNotifyService, settingService: settingService, + userPlatformQuotaRepo: userPlatformQuotaRepo, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -523,12 +526,13 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ - accountRepo: s.accountRepo, - userRepo: s.userRepo, - userSubRepo: s.userSubRepo, - billingCacheService: s.billingCacheService, - deferredService: s.deferredService, - balanceNotifyService: s.balanceNotifyService, + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + balanceNotifyService: s.balanceNotifyService, + userPlatformQuotaRepo: s.userPlatformQuotaRepo, } } @@ -1308,8 +1312,8 @@ func openAICompactSupportTier(account *Account) int { // isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / // compact-support checks used during account selection. -func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool { - if account == nil || !account.IsSchedulable() || !account.IsOpenAI() { +func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool { + if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) { return false } if requestedModel != "" && !account.IsModelSupported(requestedModel) { @@ -1442,7 +1446,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 验证账号是否可用于当前请求 // Verify account is usable for current request - if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { return nil } if s.isOpenAIAccountRuntimeBlocked(account) { @@ -1642,7 +1646,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex if clearSticky { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } - if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) { + if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) @@ -1920,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. fresh = current } - if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) { return nil } if s.isOpenAIAccountRuntimeBlocked(fresh) { @@ -1934,7 +1938,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co return nil } if s.schedulerSnapshot == nil || s.accountRepo == nil { - if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) { return nil } return account @@ -1944,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if err != nil || latest == nil { return nil } - if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) { return nil } if s.isOpenAIAccountRuntimeBlocked(latest) { @@ -2063,8 +2067,12 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) } -func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { +func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account, requestedModel ...string) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if len(requestedModel) > 0 { + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, requestedModel[0]) + return + } s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } @@ -2076,6 +2084,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco apiKeyID := getAPIKeyIDFromContext(c) logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) if restrictionResult.Enabled && !restrictionResult.Matched { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "type": "forbidden_error", @@ -2119,6 +2128,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { if c != nil { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusBadRequest, gin.H{ "error": gin.H{ "type": "invalid_request_error", @@ -2159,7 +2169,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey) if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "type": "permission_error", @@ -2488,7 +2498,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "type": "permission_error", @@ -2846,14 +2856,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Detail: upstreamDetail, }) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(ctx, resp, account, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } - return s.handleErrorResponse(ctx, resp, c, account, body) + return s.handleErrorResponse(ctx, resp, c, account, body, billingModel) } defer func() { _ = resp.Body.Close() }() @@ -2945,17 +2955,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( if account != nil && account.Type == AccountTypeOAuth { if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" - setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: http.StatusForbidden, - Passthrough: true, - Kind: "request_error", - Message: rejectMsg, - Detail: rejectReason, - }) + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ @@ -3006,7 +3006,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( apiKey := getAPIKeyFromContext(c) if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalFeatureGate) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "type": "permission_error", @@ -3229,6 +3229,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // 透传客户端请求头(安全白名单)。 allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() @@ -3347,7 +3348,8 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody) + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -3390,7 +3392,8 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) // 透传模式保留原始上游错误响应,但运行态账号状态仍需更新, // 避免粘性路由继续复用刚被限流的账号。 - _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + reqModel, _, _ := extractOpenAIRequestMetaFromBody(requestBody) + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -3951,6 +3954,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // Set authentication header req.Header.Set("authorization", "Bearer "+token) @@ -4063,6 +4067,7 @@ func (s *OpenAIGatewayService) handleErrorResponse( c *gin.Context, account *Account, requestBody []byte, + requestedModel ...string, ) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -4139,7 +4144,14 @@ func (s *OpenAIGatewayService) handleErrorResponse( } // Handle upstream error (mark account status) - shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + var reqModel string + if len(requestedModel) > 0 { + reqModel = strings.TrimSpace(requestedModel[0]) + } + if reqModel == "" { + reqModel, _, _ = extractOpenAIRequestMetaFromBody(requestBody) + } + shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body, reqModel) kind := "http_error" if shouldDisable { kind = "failover" @@ -4158,7 +4170,7 @@ func (s *OpenAIGatewayService) handleErrorResponse( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: body, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } @@ -4216,6 +4228,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( c *gin.Context, account *Account, writeError compatErrorWriter, + requestedModel ...string, ) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -4271,8 +4284,12 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( } // Track rate limits and decide whether to trigger secondary failover. + var modelForCooldown string + if len(requestedModel) > 0 { + modelForCooldown = requestedModel[0] + } shouldDisable := s.handleOpenAIAccountUpstreamError( - c.Request.Context(), account, resp.StatusCode, resp.Header, body, + c.Request.Context(), account, resp.StatusCode, resp.Header, body, modelForCooldown, ) kind := "http_error" if shouldDisable { @@ -4292,7 +4309,7 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: body, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } @@ -4905,20 +4922,22 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r if isEventStreamResponse(resp.Header) { return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) } + bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) + // For OAuth accounts, also fall back to a body-content heuristic because // the upstream may omit the Content-Type header while still sending SSE. // This heuristic is NOT applied to API-key accounts to avoid false // positives on JSON responses that coincidentally contain "data:" or // "event:" in their text content. - if account.Type == AccountTypeOAuth { - bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) - if bodyLooksLikeSSE { - return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) - } + if account.Type == AccountTypeOAuth && bodyLooksLikeSSE { + return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) } usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) if !usageOK { + if bodyLooksLikeSSE { + return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) + } return nil, fmt.Errorf("parse response: invalid json response") } usage := &usageValue @@ -5634,6 +5653,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, + Platform: PlatformFromAPIKey(apiKey), }, s.billingDeps(), s.usageBillingRepo) return err }() @@ -6315,6 +6335,7 @@ func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlocked if c == nil || err == nil { return } + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "type": "permission_error", diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 0fac4508..e8354837 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1865,6 +1865,7 @@ func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *test require.Equal(t, "application/json", req.Header.Get("Accept")) require.Equal(t, codexCLIVersion, req.Header.Get("Version")) require.NotEmpty(t, req.Header.Get("Session_Id")) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) } func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) { @@ -1885,6 +1886,7 @@ func TestOpenAIBuildUpstreamRequestCompactForcesJSONAcceptForOAuth(t *testing.T) require.Equal(t, "application/json", req.Header.Get("Accept")) require.Equal(t, codexCLIVersion, req.Header.Get("Version")) require.NotEmpty(t, req.Header.Get("Session_Id")) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) } func TestOpenAIBuildUpstreamRequestOAuthMessagesBridgeUsesSessionOnly(t *testing.T) { @@ -2303,6 +2305,35 @@ func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { require.NotContains(t, rec.Body.String(), "data:") } +func TestHandleNonStreamingResponse_APIKeyFallsBackToSSEBodyWhenContentTypeIsWrong(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"hel"}`, + `data: {"type":"response.output_text.delta","delta":"lo"}`, + `data: {"type":"response.completed","response":{"id":"resp_api_key_sse","object":"response","model":"gpt-5.4","status":"completed","output":[],"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}}`, + `data: [DONE]`, + }, "\n"))), + } + account := &Account{ID: 1, Type: AccountTypeAPIKey} + + result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, account, "gpt-5.4", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, result.InputTokens) + require.Equal(t, 2, result.OutputTokens) + require.NotContains(t, rec.Body.String(), "data:") + require.Equal(t, "resp_api_key_sse", gjson.Get(rec.Body.String(), "id").String()) + require.Equal(t, "hello", gjson.Get(rec.Body.String(), "output.0.content.0.text").String()) +} + func TestHandleSSEToJSON_ReconstructsImageGenerationOutputItemDone(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 95c054c9..1bcd947c 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -638,11 +638,11 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(upstreamCtx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody) @@ -743,6 +743,7 @@ func (s *OpenAIGatewayService) buildOpenAIImagesRequest( if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Authorization", "Bearer "+token) for key, values := range c.Request.Header { if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] { diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index b39fa609..849ad792 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -1188,11 +1188,11 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(upstreamCtx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account, requestModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + RetryableOnSameAccount: account.IsPoolMode() && account.IsPoolModeRetryableStatus(resp.StatusCode), } } return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody) diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 52903a1b..854e9f6d 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -528,6 +528,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *te require.NotNil(t, upstream.lastReq) require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String()) require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(upstream.lastReq.Context())) require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id")) diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index c0f98de4..7d503f5a 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -20,6 +20,32 @@ type FunctionCallOutputValidation struct { HasItemReferenceForAllCallIDs bool } +func isCodexToolCallContextItemType(typ string) bool { + switch strings.TrimSpace(typ) { + case "tool_call", + "function_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "mcp_tool_call": + return true + default: + return false + } +} + +func isCodexToolCallOutputItemType(typ string) bool { + switch strings.TrimSpace(typ) { + case "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output": + return true + default: + return false + } +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、 // 或显式声明 tools/tool_choice。 @@ -53,7 +79,9 @@ func NeedsToolContinuation(reqBody map[string]any) bool { return false } -// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +// AnalyzeToolContinuationSignals 单次遍历 input,提取工具输出/工具调用上下文/item_reference 相关信号。 +// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖 Codex 的所有工具输出 +// (function_call_output/tool_search_output/custom_tool_call_output/mcp_tool_call_output)。 func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { signals := ToolContinuationSignals{} if reqBody == nil { @@ -73,13 +101,13 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "tool_call", "function_call": + switch { + case isCodexToolCallContextItemType(itemType): callID, _ := itemMap["call_id"].(string) if strings.TrimSpace(callID) != "" { signals.HasToolCallContext = true } - case "function_call_output": + case isCodexToolCallOutputItemType(itemType): signals.HasFunctionCallOutput = true callID, _ := itemMap["call_id"].(string) callID = strings.TrimSpace(callID) @@ -91,7 +119,7 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign callIDs = make(map[string]struct{}) } callIDs[callID] = struct{}{} - case "item_reference": + case itemType == "item_reference": signals.HasItemReference = true idValue, _ := itemMap["id"].(string) idValue = strings.TrimSpace(idValue) @@ -123,9 +151,10 @@ func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSign } // ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: -// 1) 无 function_call_output 直接返回 -// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 1) 无工具输出直接返回 +// 2) 若已存在工具调用上下文则提前返回 // 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +// 字段名保留 FunctionCallOutput 是为了兼容既有调用点;语义覆盖所有 Codex 工具输出。 func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { result := FunctionCallOutputValidation{} if reqBody == nil { @@ -142,10 +171,10 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "function_call_output": + switch { + case isCodexToolCallOutputItemType(itemType): result.HasFunctionCallOutput = true - case "tool_call", "function_call": + case isCodexToolCallContextItemType(itemType): callID, _ := itemMap["call_id"].(string) if strings.TrimSpace(callID) != "" { result.HasToolCallContext = true @@ -168,8 +197,8 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } itemType, _ := itemMap["type"].(string) - switch itemType { - case "function_call_output": + switch { + case isCodexToolCallOutputItemType(itemType): callID, _ := itemMap["call_id"].(string) callID = strings.TrimSpace(callID) if callID == "" { @@ -177,7 +206,7 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu continue } callIDs[callID] = struct{}{} - case "item_reference": + case itemType == "item_reference": idValue, _ := itemMap["id"].(string) idValue = strings.TrimSpace(idValue) if idValue == "" { @@ -201,24 +230,25 @@ func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutpu return result } -// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 +// HasFunctionCallOutput 判断 input 是否包含任意 Codex 工具输出,用于触发续链校验。 +// 名称保留 function_call_output 是为了兼容既有调用点。 func HasFunctionCallOutput(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput } -// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, -// 用于判断 function_call_output 是否具备可关联的上下文。 +// HasToolCallContext 判断 input 是否包含带 call_id 的工具调用上下文, +// 用于判断工具输出是否具备可关联的上下文。 func HasToolCallContext(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext } -// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 +// FunctionCallOutputCallIDs 提取 input 中工具输出的 call_id 集合。 // 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 func FunctionCallOutputCallIDs(reqBody map[string]any) []string { return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs } -// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 +// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的工具输出。 func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go index 3f415d9d..0e0552f6 100644 --- a/backend/internal/service/openai_tool_continuation_test.go +++ b/backend/internal/service/openai_tool_continuation_test.go @@ -38,41 +38,57 @@ func TestNeedsToolContinuationSignals(t *testing.T) { } func TestHasFunctionCallOutput(t *testing.T) { - // 仅当 input 中存在 function_call_output 才视为续链输出。 + // 所有 Codex 工具输出都应视为续链输出,避免 WS 续链时丢失 previous_response_id。 require.False(t, HasFunctionCallOutput(nil)) - require.True(t, HasFunctionCallOutput(map[string]any{ - "input": []any{map[string]any{"type": "function_call_output"}}, - })) + for _, typ := range []string{ + "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output", + } { + require.True(t, HasFunctionCallOutput(map[string]any{ + "input": []any{map[string]any{"type": typ}}, + }), typ) + } require.False(t, HasFunctionCallOutput(map[string]any{ "input": "text", })) } func TestHasToolCallContext(t *testing.T) { - // tool_call/function_call 必须包含 call_id,才能作为可关联上下文。 + // 工具调用上下文必须包含 call_id,才能作为可关联上下文。 require.False(t, HasToolCallContext(nil)) - require.True(t, HasToolCallContext(map[string]any{ - "input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}}, - })) - require.True(t, HasToolCallContext(map[string]any{ - "input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}}, - })) + for _, typ := range []string{ + "tool_call", + "function_call", + "local_shell_call", + "tool_search_call", + "custom_tool_call", + "mcp_tool_call", + } { + require.True(t, HasToolCallContext(map[string]any{ + "input": []any{map[string]any{"type": typ, "call_id": "call_1"}}, + }), typ) + } require.False(t, HasToolCallContext(map[string]any{ "input": []any{map[string]any{"type": "tool_call"}}, })) } func TestFunctionCallOutputCallIDs(t *testing.T) { - // 仅提取非空 call_id,去重后返回。 + // 仅提取工具输出的非空 call_id,去重后返回。 require.Empty(t, FunctionCallOutputCallIDs(nil)) callIDs := FunctionCallOutputCallIDs(map[string]any{ "input": []any{ map[string]any{"type": "function_call_output", "call_id": "call_1"}, + map[string]any{"type": "tool_search_output", "call_id": "call_search"}, + map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom"}, + map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp"}, map[string]any{"type": "function_call_output", "call_id": ""}, map[string]any{"type": "function_call_output", "call_id": "call_1"}, }, }) - require.ElementsMatch(t, []string{"call_1"}, callIDs) + require.ElementsMatch(t, []string{"call_1", "call_search", "call_custom", "call_mcp"}, callIDs) } func TestHasFunctionCallOutputMissingCallID(t *testing.T) { @@ -80,8 +96,11 @@ func TestHasFunctionCallOutputMissingCallID(t *testing.T) { require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ "input": []any{map[string]any{"type": "function_call_output"}}, })) + require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{ + "input": []any{map[string]any{"type": "tool_search_output"}}, + })) require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{ - "input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}}, + "input": []any{map[string]any{"type": "tool_search_output", "call_id": "call_1"}}, })) } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 700dbedf..b8e558ae 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1548,13 +1548,35 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool { for _, item := range items { - if gjson.GetBytes(item, "type").String() == "function_call_output" { + if isCodexToolCallOutputItemType(gjson.GetBytes(item, "type").String()) { return true } } return false } +func openAIWSRawPayloadHasToolCallOutput(payload []byte) bool { + if len(payload) == 0 { + return false + } + input := gjson.GetBytes(payload, "input") + if !input.Exists() { + return false + } + if input.IsArray() { + for _, item := range input.Array() { + if isCodexToolCallOutputItemType(item.Get("type").String()) { + return true + } + } + return false + } + if input.Type == gjson.JSON { + return isCodexToolCallOutputItemType(input.Get("type").String()) + } + return false +} + func buildOpenAIWSReplayInputSequence( previousFullInput []json.RawMessage, previousFullInputExists bool, @@ -2590,6 +2612,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr) } if blocked != nil { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) // Send a Realtime-style error event to the client first, then // signal the handler to close the connection with PolicyViolation. // We intentionally do NOT forward this frame upstream. @@ -2759,6 +2782,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( var dialErr *openAIWSDialError if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests { s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error())) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseHeaders: cloneHeader(dialErr.ResponseHeaders), + } } if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { return nil, NewOpenAIWSClientCloseError( @@ -2855,7 +2882,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) - turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + turnHasFunctionCallOutput := openAIWSRawPayloadHasToolCallOutput(payload) eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 @@ -2954,6 +2981,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( false, ) } + if !wroteDownstream && isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) { + lease.MarkBroken() + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseBody: append([]byte(nil), upstreamMessage...), + ResponseHeaders: cloneHeader(lease.HandshakeHeaders()), + } + } } isTokenEvent := isOpenAIWSTokenEvent(eventType) if isTokenEvent { @@ -3131,7 +3166,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentTurnReplayInputExists := false skipBeforeTurn := false hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool { - if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() { + if openAIWSRawPayloadHasToolCallOutput(payload) { return true } return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput) @@ -3256,7 +3291,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) toolSignals := ToolContinuationSignals{ - HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(), + HasFunctionCallOutput: openAIWSRawPayloadHasToolCallOutput(currentPayload), } if toolSignals.HasFunctionCallOutput { var currentReqBody map[string]any diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index a4b39ddf..edb6fbcd 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -1223,6 +1223,141 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") } +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledToolSearchOutputAutoAttachesPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_tool_search_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 145, + Name: "openai-ingress-tool-search-output-auto-prev", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_tool_search_prev_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"tool_search_output","call_id":"call_search_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_tool_search_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.Equal(t, "resp_tool_search_prev_1", gjson.Get(secondWrite, "previous_response_id").String(), "tool_search_output 缺失 previous_response_id 时应回填上一轮响应 ID") + require.Equal(t, "tool_search_output", gjson.Get(secondWrite, "input.0.type").String()) + require.Equal(t, "call_search_1", gjson.Get(secondWrite, "input.0.call_id").String()) +} + func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index c735f50a..31c9a142 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -696,6 +696,36 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { }) } +func TestOpenAIWSRawPayloadHasToolCallOutput(t *testing.T) { + t.Parallel() + + for _, typ := range []string{ + "function_call_output", + "tool_search_output", + "custom_tool_call_output", + "mcp_tool_call_output", + } { + typ := typ + t.Run(typ, func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":[{"type":"` + typ + `","call_id":"call_1","output":"ok"}]}`) + require.True(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) + } + + t.Run("object_input", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":{"type":"tool_search_output","call_id":"call_1","output":"ok"}}`) + require.True(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) + + t.Run("non_tool_output", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"input":[{"type":"input_text","text":"hello"}]}`) + require.False(t, openAIWSRawPayloadHasToolCallOutput(payload)) + }) +} + func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index f3936de1..99e92251 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -619,6 +619,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, // userPlatformQuotaRepo ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 4ee85a3a..a3673d74 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -338,6 +338,9 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ErrorEventUsageL select { case serverErr := <-serverErrCh: require.Error(t, serverErr) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, serverErr, &failoverErr) + require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode) require.Len(t, repo.rateLimitCalls, 1) require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) case <-time.After(5 * time.Second): diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go index 2b7e2add..35c7569d 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -55,14 +55,18 @@ type RelayExit struct { } type RelayOptions struct { - WriteTimeout time.Duration - IdleTimeout time.Duration - UpstreamDrainTimeout time.Duration - FirstMessageType coderws.MessageType - OnUsageParseFailure func(eventType string, usageRaw string) - OnTurnComplete func(turn RelayTurnResult) - OnTrace func(event RelayTraceEvent) - Now func() time.Time + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + FirstMessageSent bool + StartClientAfterFirstDownstream bool + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + BeforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error + ReadClientFrame func(ctx context.Context, clientConn FrameConn) (coderws.MessageType, []byte, error) + OnTrace func(event RelayTraceEvent) + Now func() time.Time } type RelayTraceEvent struct { @@ -170,29 +174,47 @@ func Relay( MessageType: relayMessageTypeString(firstMessageType), }) - if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { - result.Duration = nowFn().Sub(startAt) + if options.FirstMessageSent { emitRelayTrace(onTrace, RelayTraceEvent{ - Stage: "write_first_message_failed", + Stage: "write_first_message_skipped", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + }) + } else { + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", Direction: "client_to_upstream", MessageType: relayMessageTypeString(firstMessageType), PayloadBytes: len(firstClientMessage), - Error: err.Error(), }) - return result, &RelayExit{Stage: "write_upstream", Err: err} } clientToUpstreamFrames.Add(1) - emitRelayTrace(onTrace, RelayTraceEvent{ - Stage: "write_first_message_ok", - Direction: "client_to_upstream", - MessageType: relayMessageTypeString(firstMessageType), - PayloadBytes: len(firstClientMessage), - }) markActivity() exitCh := make(chan relayExitSignal, 3) dropDownstreamWrites := atomic.Bool{} - go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + clientReaderStarted := atomic.Bool{} + startClientReader := func() { + if !clientReaderStarted.CompareAndSwap(false, true) { + return + } + go runClientToUpstream(relayCtx, clientConn, options.ReadClientFrame, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + } + if !options.StartClientAfterFirstDownstream { + startClientReader() + } go runUpstreamToClient( relayCtx, upstreamConn, @@ -202,6 +224,12 @@ func Relay( state, options.OnUsageParseFailure, options.OnTurnComplete, + options.BeforeWriteClient, + func() { + if options.StartClientAfterFirstDownstream { + startClientReader() + } + }, &dropDownstreamWrites, upstreamToClientFrames, droppedDownstreamFrames, @@ -230,7 +258,9 @@ func Relay( } else { relayCancel() _ = upstreamConn.Close() - secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + if clientReaderStarted.Load() { + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } } if hasSecondExit { combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream @@ -250,6 +280,14 @@ func Relay( result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() result.UpstreamToClientFrames = upstreamToClientFrames.Load() result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if options.FirstMessageSent && firstExit.stage == "read_client" && firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_client_closed", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + return result, nil + } if firstExit.stage == "read_client" && firstExit.graceful { stage := "client_disconnected" exitErr := firstExit.err @@ -310,6 +348,14 @@ func Relay( WroteDownstream: combinedWroteDownstream, } } + if options.FirstMessageSent { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_client_closed", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + return result, nil + } emitRelayTrace(onTrace, RelayTraceEvent{ Stage: "relay_complete", Graceful: true, @@ -322,14 +368,20 @@ func Relay( func runClientToUpstream( ctx context.Context, clientConn FrameConn, + readClientFrame func(context.Context, FrameConn) (coderws.MessageType, []byte, error), writeUpstream func(msgType coderws.MessageType, payload []byte) error, markActivity func(), forwardedFrames *atomic.Int64, onTrace func(event RelayTraceEvent), exitCh chan<- relayExitSignal, ) { + if readClientFrame == nil { + readClientFrame = func(ctx context.Context, conn FrameConn) (coderws.MessageType, []byte, error) { + return conn.ReadFrame(ctx) + } + } for { - msgType, payload, err := clientConn.ReadFrame(ctx) + msgType, payload, err := readClientFrame(ctx, clientConn) if err != nil { emitRelayTrace(onTrace, RelayTraceEvent{ Stage: "read_client_failed", @@ -368,6 +420,8 @@ func runUpstreamToClient( state *relayState, onUsageParseFailure func(eventType string, usageRaw string), onTurnComplete func(turn RelayTurnResult), + beforeWriteClient func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error, + afterWriteClient func(), dropDownstreamWrites *atomic.Bool, forwardedFrames *atomic.Int64, droppedFrames *atomic.Int64, @@ -395,6 +449,24 @@ func runUpstreamToClient( return } markActivity() + if beforeWriteClient != nil { + if err := beforeWriteClient(msgType, payload, wroteDownstream); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "upstream_message_rejected", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{ + stage: "upstream_message", + err: err, + wroteDownstream: wroteDownstream, + } + return + } + } observedEvent := observedUpstreamEvent{} switch msgType { case coderws.MessageText: @@ -438,6 +510,9 @@ func runUpstreamToClient( return } wroteDownstream = true + if afterWriteClient != nil { + afterWriteClient() + } if forwardedFrames != nil { forwardedFrames.Add(1) } diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go index 123e10ce..52104482 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -45,6 +45,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { runClientToUpstream( context.Background(), newPassthroughTestFrameConn(nil, true), + nil, func(_ coderws.MessageType, _ []byte) error { return nil }, func() {}, nil, @@ -65,6 +66,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { newPassthroughTestFrameConn([]passthroughTestFrame{ {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, }, true), + nil, func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, func() {}, nil, @@ -87,6 +89,7 @@ func TestRunClientToUpstream_ErrorPaths(t *testing.T) { newPassthroughTestFrameConn([]passthroughTestFrame{ {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, }, true), + nil, func(_ coderws.MessageType, _ []byte) error { return nil }, func() {}, forwarded, @@ -120,6 +123,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, nil, @@ -149,6 +154,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, nil, @@ -181,6 +188,8 @@ func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { &relayState{}, nil, nil, + nil, + nil, drop, nil, dropped, diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 0a89e2dd..17543dc0 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -280,6 +280,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) } if blocked != nil { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) // coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires // writeFrameMu, writes the entire frame, and Flushes the underlying // bufio writer before returning (write.go:42 → write.go:307-311). @@ -358,6 +359,13 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( statusCode, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), ) + if statusCode == http.StatusTooManyRequests { + s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error())) + return &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseHeaders: cloneHeader(handshakeHeaders), + } + } return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) } defer func() { @@ -442,6 +450,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( return out, blocked, policyErr }, onBlock: func(blocked *OpenAIFastBlockedError) { + MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) // See note above on Conn.Write being synchronous w.r.t. flush; // no explicit flush is required to ensure the error event lands // before the close frame. @@ -454,15 +463,46 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( cancel() }, } + upstreamFirstMessageSent := false + firstWriteCtx, cancelFirstWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + firstWriteErr := upstreamFrameConn.WriteFrame(firstWriteCtx, coderws.MessageText, firstClientMessage) + cancelFirstWrite() + if firstWriteErr != nil { + return wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write first upstream websocket request: %w", firstWriteErr), + false, + ) + } + upstreamFirstMessageSent = true + + readNextClientFrame := func(readCtx context.Context, conn openaiwsv2.FrameConn) (coderws.MessageType, []byte, error) { + for { + msgType, payload, readErr := conn.ReadFrame(readCtx) + if readErr != nil { + return msgType, payload, readErr + } + if msgType == coderws.MessageText && strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + return msgType, payload, nil + } + if writeErr := upstreamFrameConn.WriteFrame(readCtx, msgType, payload); writeErr != nil { + return msgType, payload, writeErr + } + } + } + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ Ctx: ctx, ClientConn: policyClientConn, UpstreamConn: upstreamFrameConn, FirstClientMessage: firstClientMessage, Options: openaiwsv2.RelayOptions{ - WriteTimeout: s.openAIWSWriteTimeout(), - IdleTimeout: s.openAIWSPassthroughIdleTimeout(), - FirstMessageType: coderws.MessageText, + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: coderws.MessageText, + FirstMessageSent: upstreamFirstMessageSent, + StartClientAfterFirstDownstream: true, + ReadClientFrame: readNextClientFrame, OnUsageParseFailure: func(eventType string, usageRaw string) { logOpenAIWSV2Passthrough( "usage_parse_failed event_type=%s usage_raw=%s", @@ -505,6 +545,31 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( hooks.AfterTurn(turnNo, turnResult, nil) } }, + BeforeWriteClient: func(msgType coderws.MessageType, payload []byte, wroteDownstream bool) error { + if msgType != coderws.MessageText || wroteDownstream { + return nil + } + if eventType, _, _ := parseOpenAIWSEventEnvelope(payload); eventType != "error" { + return nil + } + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(payload) + if !isOpenAIWSRateLimitError(errCodeRaw, errTypeRaw, errMsgRaw) { + return nil + } + s.persistOpenAIWSRateLimitSignal(ctx, account, handshakeHeaders, payload, errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSV2Passthrough( + "relay_rate_limit_failover account_id=%d err_code=%s err_type=%s err_message=%s", + account.ID, + truncateOpenAIWSLogValue(errCodeRaw, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(errMsgRaw, openAIWSLogValueMaxLen), + ) + return &UpstreamFailoverError{ + StatusCode: http.StatusTooManyRequests, + ResponseBody: append([]byte(nil), payload...), + ResponseHeaders: cloneHeader(handshakeHeaders), + } + }, OnTrace: func(event openaiwsv2.RelayTraceEvent) { logOpenAIWSV2Passthrough( "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index 6c337071..b654190c 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -538,7 +538,8 @@ SELECT COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429), 0) AS upstream_429, COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529), 0) AS upstream_529 FROM ops_error_logs -WHERE created_at >= $1 AND created_at < $2` +WHERE created_at >= $1 AND created_at < $2 + AND is_count_tokens = FALSE` if err := c.db.QueryRowContext(ctx, q, start, end).Scan( &errorTotal, diff --git a/backend/internal/service/ops_metrics_collector_test.go b/backend/internal/service/ops_metrics_collector_test.go new file mode 100644 index 00000000..5b069c1b --- /dev/null +++ b/backend/internal/service/ops_metrics_collector_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestWriteOpenAIFastPolicyBlockedResponseMarksBusinessLimited(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + writeOpenAIFastPolicyBlockedResponse(c, &OpenAIFastBlockedError{Message: "custom fast policy block"}) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.True(t, HasOpsClientBusinessLimited(c)) + reason, ok := c.Get(OpsClientBusinessLimitedReasonKey) + require.True(t, ok) + require.Equal(t, OpsClientBusinessLimitedReasonLocalPolicyDenied, reason) +} + +func TestOpsMetricsCollectorQueryErrorCountsExcludesCountTokens(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + + collector := &OpsMetricsCollector{db: db} + start := time.Date(2026, 5, 26, 10, 0, 0, 0, time.UTC) + end := start.Add(time.Hour) + + mock.ExpectQuery(`(?s)FROM ops_error_logs\s+WHERE created_at >= \$1 AND created_at < \$2\s+AND is_count_tokens = FALSE`). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{ + "error_total", + "business_limited", + "error_sla", + "upstream_excl", + "upstream_429", + "upstream_529", + }).AddRow(int64(5), int64(2), int64(3), int64(1), int64(1), int64(1))) + + errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, err := collector.queryErrorCounts(context.Background(), start, end) + require.NoError(t, err) + require.Equal(t, int64(5), errorTotal) + require.Equal(t, int64(2), businessLimited) + require.Equal(t, int64(3), errorSLA) + require.Equal(t, int64(1), upstreamExcl429529) + require.Equal(t, int64(1), upstream429) + require.Equal(t, int64(1), upstream529) + require.NoError(t, mock.ExpectationsWereMet()) + mock.ExpectClose() + require.NoError(t, db.Close()) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index b4ff0e74..2405f306 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -34,9 +34,13 @@ const ( // Client-side configuration denials should remain visible in ops_error_logs, // but should be excluded from SLA/error-rate calculations. - OpsClientBusinessLimitedKey = "ops_client_business_limited" - OpsClientBusinessLimitedReasonKey = "ops_client_business_limited_reason" - OpsClientBusinessLimitedReasonIPRestriction = "api_key_ip_restriction" + OpsClientBusinessLimitedKey = "ops_client_business_limited" + OpsClientBusinessLimitedReasonKey = "ops_client_business_limited_reason" + OpsClientBusinessLimitedReasonIPRestriction = "api_key_ip_restriction" + OpsClientBusinessLimitedReasonAPIKeyGroupUnavailable = "api_key_group_unavailable" + OpsClientBusinessLimitedReasonAPIKeyGroupUnassigned = "api_key_group_unassigned" + OpsClientBusinessLimitedReasonLocalFeatureGate = "local_feature_gate" + OpsClientBusinessLimitedReasonLocalPolicyDenied = "local_policy_denied" ) func SetOpsLatencyMs(c *gin.Context, key string, value int64) { diff --git a/backend/internal/service/post_billing_platform_test.go b/backend/internal/service/post_billing_platform_test.go new file mode 100644 index 00000000..a3ce0676 --- /dev/null +++ b/backend/internal/service/post_billing_platform_test.go @@ -0,0 +1,81 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +func TestPlatformFromAPIKey_NilSafe(t *testing.T) { + if got := PlatformFromAPIKey(nil); got != "" { + t.Errorf("nil APIKey should yield empty string, got %q", got) + } +} + +func TestPlatformFromAPIKey_NilGroup(t *testing.T) { + k := &APIKey{Group: nil} + if got := PlatformFromAPIKey(k); got != "" { + t.Errorf("APIKey with nil Group should yield empty string, got %q", got) + } +} + +func TestPlatformFromAPIKey_DerivesFromGroup(t *testing.T) { + tests := []struct { + name string + platform string + }{ + {"anthropic", "anthropic"}, + {"openai", "openai"}, + {"gemini", "gemini"}, + {"antigravity", "antigravity"}, + {"empty", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &APIKey{ + Group: &Group{Platform: tt.platform}, + } + got := PlatformFromAPIKey(k) + if got != tt.platform { + t.Errorf("PlatformFromAPIKey(%q) = %q, want %q", tt.platform, got, tt.platform) + } + }) + } +} + +// TestQuotaPlatform 锁定配额计量口径:ForcePlatform 路由(如 /antigravity)按 ForcePlatform 计, +// 否则回退到 Group 平台。preflight 与 post-billing 共用此口径,保证一致。 +func TestQuotaPlatform(t *testing.T) { + apiKey := &APIKey{Group: &Group{Platform: PlatformAnthropic}} + + t.Run("no force platform falls back to group platform", func(t *testing.T) { + if got := QuotaPlatform(context.Background(), apiKey); got != PlatformAnthropic { + t.Errorf("QuotaPlatform without force = %q, want %q", got, PlatformAnthropic) + } + }) + + t.Run("force platform overrides group platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity) + if got := QuotaPlatform(ctx, apiKey); got != PlatformAntigravity { + t.Errorf("QuotaPlatform with force = %q, want %q", got, PlatformAntigravity) + } + }) + + t.Run("empty force platform falls back to group platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, "") + if got := QuotaPlatform(ctx, apiKey); got != PlatformAnthropic { + t.Errorf("QuotaPlatform with empty force = %q, want %q", got, PlatformAnthropic) + } + }) + + t.Run("nil api key with force platform returns force platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity) + if got := QuotaPlatform(ctx, nil); got != PlatformAntigravity { + t.Errorf("QuotaPlatform(nil) with force = %q, want %q", got, PlatformAntigravity) + } + }) +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index c3b160e7..ecbd86d1 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -153,7 +153,7 @@ func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Accoun // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 -func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { +func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte, requestedModel ...string) (shouldDisable bool) { customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() // 池模式默认不标记本地账号状态;仅当用户显式配置自定义错误码时按本地策略处理。 @@ -169,6 +169,10 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc return false } + if len(requestedModel) > 0 && s.HandleUpstreamModelNotFound(ctx, account, requestedModel[0], statusCode, responseBody) { + return true + } + // 先尝试临时不可调度规则(401除外) // 如果匹配成功,直接返回,不执行后续禁用逻辑 if statusCode != 401 { @@ -244,17 +248,15 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc shouldDisable = true break } - // 2. 设置 expires_at 为当前时间,强制下次请求刷新 token - if account.Credentials == nil { - account.Credentials = make(map[string]any) - } - account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) - if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil { - slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) - } else { - slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) - } - // 3. 临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取) + // 2. 临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取) + // 注意:此处不再写回 account.Credentials/expires_at。 + // 原实现使用请求开始时的 account 快照整列覆盖 credentials JSONB(见 + // persistAccountCredentials → accountRepository.UpdateCredentials → SetCredentials), + // 在另一个 worker 刚刷新完 refresh_token 的窄窗口内会把新 refresh_token 回滚为旧值, + // 导致下一周期用旧 refresh_token 调上游拿到 invalid_grant 后, + // tryRecoverFromRefreshRace 重读 DB 发现 currentRT == usedRT 也救不回来,账号被错误 disable。 + // 这里仅依赖 InvalidateToken + SetTempUnschedulable 让账号在冷却期内不被调度, + // 冷却结束后由 token_provider 的 NeedsRefresh / token_refresh_service 走带分布式锁的正路刷新。 msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { msg = "OAuth 401: " + upstreamMsg @@ -1616,9 +1618,51 @@ func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account return s.tryTempUnschedulable(ctx, account, statusCode, responseBody) } +const upstreamModelNotFoundCooldown = 30 * time.Minute +const upstreamModelNotFoundReason = "upstream_404_model_not_found" const tempUnschedBodyMaxBytes = 64 << 10 const tempUnschedMessageMaxBytes = 2048 +func (s *RateLimitService) HandleUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string, statusCode int, responseBody []byte) bool { + if s == nil || account == nil || s.accountRepo == nil { + return false + } + if !account.ShouldHandleErrorCode(statusCode) { + return false + } + if !isUpstreamModelNotFoundError(statusCode, responseBody) { + return false + } + modelKey := modelRateLimitKeyForUpstreamModelNotFound(ctx, account, requestedModel) + if modelKey == "" { + return false + } + resetAt := time.Now().Add(upstreamModelNotFoundCooldown) + if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, resetAt, upstreamModelNotFoundReason); err != nil { + slog.Warn("upstream_model_not_found_set_model_rate_limit_failed", "account_id", account.ID, "model", modelKey, "error", err) + return true + } + slog.Info("upstream_model_not_found_model_rate_limited", "account_id", account.ID, "model", modelKey, "reset_at", resetAt) + return true +} + +func modelRateLimitKeyForUpstreamModelNotFound(ctx context.Context, account *Account, requestedModel string) string { + modelKey := strings.TrimSpace(requestedModel) + if account == nil || modelKey == "" { + return modelKey + } + if account.Platform == PlatformAntigravity { + if resolved := strings.TrimSpace(resolveFinalAntigravityModelKey(ctx, account, modelKey)); resolved != "" { + return resolved + } + return modelKey + } + if mapped := strings.TrimSpace(account.GetMappedModel(modelKey)); mapped != "" { + return mapped + } + return modelKey +} + func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool { if account == nil { return false diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index a964775e..873aaf33 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -129,7 +129,10 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t } // TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError -// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable +// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable。 +// 注意:401 handler 不再回写 credentials(避免请求开始时的快照整列覆盖 DB +// 把另一个 worker 刚刷新出来的新 refresh_token 回滚为旧值), +// 因此 updateCredentialsCalls 应当为 0。 func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} @@ -149,7 +152,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) - require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.updateCredentialsCalls) require.Len(t, invalidator.accounts, 1) } @@ -171,7 +174,12 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { require.Empty(t, invalidator.accounts) } -func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) { +// TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials +// 回归测试:确保 401 handler 不再使用请求开始时的 account 快照写回 credentials。 +// 原实现会通过 persistAccountCredentials → UpdateCredentials → SetCredentials +// 整列覆盖 credentials JSONB,在另一个 worker 刚刷新完 refresh_token 的窄窗口内 +// 会把新 refresh_token 回滚为快照中的旧值,导致下一周期拿 invalid_grant 被错误 disable。 +func TestRateLimitService_HandleUpstreamError_OAuth401DoesNotOverwriteCredentials(t *testing.T) { repo := &rateLimitAccountRepoStub{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) account := &Account{ @@ -187,8 +195,9 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t * shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.updateCredentialsCalls) - require.NotEmpty(t, repo.lastCredentials["expires_at"]) + require.Equal(t, 0, repo.updateCredentialsCalls, "401 handler must not write credentials back from the request-start snapshot") + require.Equal(t, 1, repo.tempCalls, "401 handler should still set temp-unschedulable cooldown") + require.Nil(t, repo.lastCredentials, "no credentials should have been persisted") } // 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用, diff --git a/backend/internal/service/ratelimit_service_model_not_found_test.go b/backend/internal/service/ratelimit_service_model_not_found_test.go new file mode 100644 index 00000000..dfd18c5f --- /dev/null +++ b/backend/internal/service/ratelimit_service_model_not_found_test.go @@ -0,0 +1,127 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type modelNotFoundRateLimitCall struct { + accountID int64 + scope string + resetAt time.Time + reason string +} + +type modelNotFoundAccountRepoStub struct { + mockAccountRepoForGemini + tempCalls int + modelRateLimitCalls []modelNotFoundRateLimitCall + modelRateLimitErr error +} + +func (r *modelNotFoundAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +func (r *modelNotFoundAccountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error { + call := modelNotFoundRateLimitCall{ + accountID: id, + scope: scope, + resetAt: resetAt, + } + if len(reason) > 0 { + call.reason = reason[0] + } + r.modelRateLimitCalls = append(r.modelRateLimitCalls, call) + return r.modelRateLimitErr +} + +func TestRateLimitService_HandleUpstreamError_ModelNotFoundUsesModelRateLimit(t *testing.T) { + repo := &modelNotFoundAccountRepoStub{} + svc := &RateLimitService{accountRepo: repo} + account := openAIModelNotFoundTempAccount() + + handled := svc.HandleUpstreamError( + context.Background(), + account, + http.StatusNotFound, + http.Header{}, + []byte(`{"error":{"code":"model_not_found","message":"model not found"}}`), + "gpt-5.4", + ) + + require.True(t, handled) + require.Zero(t, repo.tempCalls) + require.Len(t, repo.modelRateLimitCalls, 1) + call := repo.modelRateLimitCalls[0] + require.Equal(t, account.ID, call.accountID) + require.Equal(t, "gpt-5.4", call.scope) + require.Equal(t, upstreamModelNotFoundReason, call.reason) + require.WithinDuration(t, time.Now().Add(upstreamModelNotFoundCooldown), call.resetAt, 5*time.Second) +} + +func TestRateLimitService_HandleUpstreamError_ModelNotFoundWriteFailureDoesNotTempUnschedule(t *testing.T) { + repo := &modelNotFoundAccountRepoStub{modelRateLimitErr: errors.New("write failed")} + svc := &RateLimitService{accountRepo: repo} + account := openAIModelNotFoundTempAccount() + + handled := svc.HandleUpstreamError( + context.Background(), + account, + http.StatusNotFound, + http.Header{}, + []byte(`{"error":{"code":"model_not_found","message":"model not found"}}`), + "gpt-5.4", + ) + + require.True(t, handled) + require.Zero(t, repo.tempCalls) + require.Len(t, repo.modelRateLimitCalls, 1) +} + +func TestRateLimitService_HandleUpstreamError_Bare404KeepsTempUnschedulablePath(t *testing.T) { + repo := &modelNotFoundAccountRepoStub{} + svc := &RateLimitService{accountRepo: repo} + account := openAIModelNotFoundTempAccount() + + handled := svc.HandleUpstreamError( + context.Background(), + account, + http.StatusNotFound, + http.Header{}, + []byte(`{"error":{"message":"endpoint not found"}}`), + "gpt-5.4", + ) + + require.True(t, handled) + require.Equal(t, 1, repo.tempCalls) + require.Empty(t, repo.modelRateLimitCalls) +} + +func openAIModelNotFoundTempAccount() *Account { + return &Account{ + ID: 101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(http.StatusNotFound), + "keywords": []any{"not found"}, + "duration_minutes": float64(10), + }, + }, + }, + } +} diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index e9de5f71..2318af79 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -140,7 +140,7 @@ func (m *sessionWindowMockRepo) ListSchedulableUngroupedByPlatforms(context.Cont func (m *sessionWindowMockRepo) SetRateLimited(context.Context, int64, time.Time) error { panic("unexpected") } -func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time) error { +func (m *sessionWindowMockRepo) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error { panic("unexpected") } func (m *sessionWindowMockRepo) SetOverloaded(context.Context, int64, time.Time) error { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 5eef2c13..e6f0f2bc 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -165,12 +165,20 @@ type SettingService struct { openAICodexUASF singleflight.Group } +// DefaultPlatformQuotaSetting 单 platform 三档限额(nil = 沿用上层;0 = 显式禁用;>0 = 上限) +type DefaultPlatformQuotaSetting struct { + DailyLimitUSD *float64 `json:"daily"` + WeeklyLimitUSD *float64 `json:"weekly"` + MonthlyLimitUSD *float64 `json:"monthly"` +} + type ProviderDefaultGrantSettings struct { Balance float64 Concurrency int Subscriptions []DefaultSubscriptionSetting GrantOnSignup bool GrantOnFirstBind bool + PlatformQuotas map[string]*DefaultPlatformQuotaSetting // key = platform name } type AuthSourceDefaultSettings struct { @@ -185,62 +193,80 @@ type AuthSourceDefaultSettings struct { } type authSourceDefaultKeySet struct { + // source 是 auth source 标识(如 "email"、"github"),仅用于 parse 时 + // slog.Warn 诊断输出,不再参与 key 拼接(platformQuotas 字段已存完整 key)。 + source string balance string concurrency string subscriptions string grantOnSignup string grantOnFirstBind string + platformQuotas string // SettingKeyAuthSourcePlatformQuotas(source) } var ( emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "email", balance: SettingKeyAuthSourceDefaultEmailBalance, concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("email"), } linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "linuxdo", balance: SettingKeyAuthSourceDefaultLinuxDoBalance, concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("linuxdo"), } oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "oidc", balance: SettingKeyAuthSourceDefaultOIDCBalance, concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("oidc"), } weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "wechat", balance: SettingKeyAuthSourceDefaultWeChatBalance, concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("wechat"), } gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "github", balance: SettingKeyAuthSourceDefaultGitHubBalance, concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency, subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("github"), } googleAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "google", balance: SettingKeyAuthSourceDefaultGoogleBalance, concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency, subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("google"), } dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "dingtalk", balance: SettingKeyAuthSourceDefaultDingTalkBalance, concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency, subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("dingtalk"), } ) @@ -1804,9 +1830,41 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) + // 系统全局 platform quota:整体替换语义(null/缺省 = 不限制)。 + if settings.DefaultPlatformQuotas != nil { + if err := validateDefaultPlatformQuotaMap(settings.DefaultPlatformQuotas); err != nil { + return nil, err + } + blob, err := json.Marshal(settings.DefaultPlatformQuotas) + if err != nil { + return nil, fmt.Errorf("marshal default platform quotas: %w", err) + } + updates[SettingKeyDefaultPlatformQuotas] = string(blob) + } + return updates, nil } +// validateDefaultPlatformQuotaMap 校验 platform quota map 的合法性: +// 平台名须在 AllowedQuotaPlatforms 白名单内,每个非 nil 上限须 finite 且 >= 0。 +// 系统层和 auth-source 层共用此 helper。 +func validateDefaultPlatformQuotaMap(m map[string]*DefaultPlatformQuotaSetting) error { + for platform, pq := range m { + if !IsAllowedQuotaPlatform(platform) { + return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", fmt.Sprintf("unknown platform %q", platform)) + } + if pq == nil { + continue + } + for _, v := range []*float64{pq.DailyLimitUSD, pq.WeeklyLimitUSD, pq.MonthlyLimitUSD} { + if v != nil && (*v < 0 || math.IsNaN(*v) || math.IsInf(*v, 0)) { + return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", "platform quota limit must be a finite non-negative number") + } + } + } + return nil +} + func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { if settings == nil { return nil, nil @@ -1826,6 +1884,26 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett } } + // 校验各 auth source 的 platform quota map(改动 C:对等系统层校验) + for _, pgs := range []struct { + name string + pq map[string]*DefaultPlatformQuotaSetting + }{ + {"email", settings.Email.PlatformQuotas}, + {"linuxdo", settings.LinuxDo.PlatformQuotas}, + {"oidc", settings.OIDC.PlatformQuotas}, + {"wechat", settings.WeChat.PlatformQuotas}, + {"github", settings.GitHub.PlatformQuotas}, + {"google", settings.Google.PlatformQuotas}, + {"dingtalk", settings.DingTalk.PlatformQuotas}, + } { + if pgs.pq != nil { + if err := validateDefaultPlatformQuotaMap(pgs.pq); err != nil { + return nil, err + } + } + } + updates := make(map[string]string, 36) writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) @@ -2386,6 +2464,13 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut SettingKeyAuthSourceDefaultDingTalkSubscriptions, SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, + SettingKeyAuthSourcePlatformQuotas("email"), + SettingKeyAuthSourcePlatformQuotas("linuxdo"), + SettingKeyAuthSourcePlatformQuotas("oidc"), + SettingKeyAuthSourcePlatformQuotas("wechat"), + SettingKeyAuthSourcePlatformQuotas("github"), + SettingKeyAuthSourcePlatformQuotas("google"), + SettingKeyAuthSourcePlatformQuotas("dingtalk"), SettingKeyForceEmailOnThirdPartySignup, } @@ -3179,6 +3264,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.AccountQuotaNotifyEmails = []NotifyEmailEntry{} } + // 系统层默认 platform quota(修复 Bug B:parseSettings 不填充导致回显恒为 nil) + if raw := settings[SettingKeyDefaultPlatformQuotas]; raw != "" { + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] parseSettings: unmarshal default_platform_quotas failed", "error", err) + } else { + result.DefaultPlatformQuotas = parsed + } + } + return result } @@ -3271,6 +3366,15 @@ func parseProviderDefaultGrantSettings(settings map[string]string, keys authSour result.GrantOnFirstBind = raw == "true" } + if raw := settings[keys.platformQuotas]; raw != "" { + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] parseProviderDefaultGrantSettings: unmarshal auth source platform quotas failed", "source", keys.source, "error", err) + } else { + result.PlatformQuotas = parsed + } + } + return result } @@ -3289,6 +3393,17 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource updates[keys.subscriptions] = string(raw) updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) + + // auth source platform quota:整体替换语义。 + // nil = 请求未携带该字段,跳过写入以保留既有配置(与系统层 buildSystemSettingsUpdates 的 + // DefaultPlatformQuotas nil 守卫一致);非 nil(含空 map)才整体替换。二者语义不可混同。 + if keys.platformQuotas != "" && settings.PlatformQuotas != nil { + blob, err := json.Marshal(settings.PlatformQuotas) + if err != nil { + blob = []byte("{}") + } + updates[keys.platformQuotas] = string(blob) + } } func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { @@ -4493,3 +4608,63 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +// GetDefaultPlatformQuotas 读取系统全局 platform quota JSON key,返回 4 platform x 3 window 的设置。 +// 永远返回包含全部 4 platform key 的 map(值可能为零值/nil 字段,表示"上层未配置 = 不限制")。 +// +// 使用单个 JSON key(default_platform_quotas),一次 DB roundtrip,消除旧 12-KV 格式的 N+1 问题。 +// 容错语义:取值失败或 unmarshal 失败 → 返回补齐 4 key 的空 map(fail-open,注册不被阻断)。 +func (s *SettingService) GetDefaultPlatformQuotas(ctx context.Context) (map[string]*DefaultPlatformQuotaSetting, error) { + out := map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {}, + "openai": {}, + "gemini": {}, + "antigravity": {}, + } + raw, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultPlatformQuotas) + if err != nil || raw == "" { + return out, nil // 无配置 = 全部不限制 + } + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] unmarshal default_platform_quotas failed (fail-open)", "error", err) + return out, nil + } + for _, platform := range AllowedQuotaPlatforms { + if v := parsed[platform]; v != nil { + out[platform] = v + } + } + return out, nil // 补齐 4 platform key,保持与旧实现一致的下游契约 +} + +// GetAuthSourcePlatformQuotas 读取指定 auth source 的 platform quota 覆盖(仅返回有配置的平台,override 语义)。 +func (s *SettingService) GetAuthSourcePlatformQuotas(ctx context.Context, source string) map[string]*DefaultPlatformQuotaSetting { + out := map[string]*DefaultPlatformQuotaSetting{} + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAuthSourcePlatformQuotas(source)) + if err != nil || raw == "" { + return out // 无 override + } + if err := json.Unmarshal([]byte(raw), &out); err != nil { + slog.Warn("[Setting] unmarshal auth source platform quotas failed (fail-open)", "source", source, "error", err) + return map[string]*DefaultPlatformQuotaSetting{} + } + return out // 仅含已配置平台,保持 override 语义 +} + +// mergePlatformQuotaDefaults 按字段级 patch:src 中非 nil 字段覆盖 dst。 +// 区分 nil("未配置",保留 dst)vs &0.0("显式禁用",覆盖 dst 为 0) +func mergePlatformQuotaDefaults(dst, src *DefaultPlatformQuotaSetting) { + if src == nil || dst == nil { + return + } + if src.DailyLimitUSD != nil { + dst.DailyLimitUSD = src.DailyLimitUSD + } + if src.WeeklyLimitUSD != nil { + dst.WeeklyLimitUSD = src.WeeklyLimitUSD + } + if src.MonthlyLimitUSD != nil { + dst.MonthlyLimitUSD = src.MonthlyLimitUSD + } +} diff --git a/backend/internal/service/setting_service_platform_quota_test.go b/backend/internal/service/setting_service_platform_quota_test.go new file mode 100644 index 00000000..557cc5f1 --- /dev/null +++ b/backend/internal/service/setting_service_platform_quota_test.go @@ -0,0 +1,344 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestMergePlatformQuotaDefaults_PatchSemantics(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{ + DailyLimitUSD: &five, + WeeklyLimitUSD: &five, + } + ten := 10.0 + patch := DefaultPlatformQuotaSetting{DailyLimitUSD: &ten} + + mergePlatformQuotaDefaults(&base, &patch) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 10.0 { + t.Errorf("daily not patched: %+v", base.DailyLimitUSD) + } + if base.WeeklyLimitUSD == nil || *base.WeeklyLimitUSD != 5.0 { + t.Errorf("weekly should remain 5.0: %+v", base.WeeklyLimitUSD) + } +} + +func TestMergePlatformQuotaDefaults_ZeroIsExplicitDisable(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{DailyLimitUSD: &five} + zero := 0.0 + patch := DefaultPlatformQuotaSetting{DailyLimitUSD: &zero} + + mergePlatformQuotaDefaults(&base, &patch) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 0 { + t.Errorf("explicit 0 should patch base, got %+v", base.DailyLimitUSD) + } +} + +func TestMergePlatformQuotaDefaults_NilSrcIsNoop(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{DailyLimitUSD: &five} + mergePlatformQuotaDefaults(&base, nil) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 5.0 { + t.Errorf("nil src should be no-op: %+v", base.DailyLimitUSD) + } +} + +func floatPtrPQ(v float64) *float64 { return &v } + +func newSettingServiceForPlatformQuotaTest(seed map[string]string) *SettingService { + repo := newMockSettingRepo() + for k, v := range seed { + repo.data[k] = v + } + return NewSettingService(repo, &config.Config{}) +} + +func TestGetDefaultPlatformQuotas_ReturnsFourPlatforms(t *testing.T) { + zero := 0.0 + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + // 新 JSON 格式:anthropic daily=10.5, openai monthly=0, gemini/antigravity 无配置 + SettingKeyDefaultPlatformQuotas: `{"anthropic":{"daily":10.5},"openai":{"monthly":0}}`, + }) + got, err := svc.GetDefaultPlatformQuotas(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // 必须包含全部 4 个 platform key(补齐契约) + for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[platform]; !ok { + t.Errorf("missing platform key: %q", platform) + } + } + // anthropic daily = 10.5 + if v := got["anthropic"].DailyLimitUSD; v == nil || *v != 10.5 { + t.Errorf("anthropic daily want 10.5, got %v", v) + } + // openai monthly = 0(显式禁用) + if v := got["openai"].MonthlyLimitUSD; v == nil || *v != zero { + t.Errorf("openai monthly want 0 (explicit disable), got %v", v) + } + // gemini 无配置 → weekly = nil + if v := got["gemini"].WeeklyLimitUSD; v != nil { + t.Errorf("gemini weekly want nil (not configured), got %v", *v) + } + // antigravity 无配置 → daily = nil + if v := got["antigravity"].DailyLimitUSD; v != nil { + t.Errorf("antigravity daily want nil (not configured), got %v", *v) + } +} + +func TestGetAuthSourcePlatformQuotas_OnlyConfiguredReturned(t *testing.T) { + source := "email" + // 新 JSON 格式:anthropic daily=5, monthly=100;openai weekly=0;gemini/antigravity 无配置 + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas(source): `{"anthropic":{"daily":5,"monthly":100},"openai":{"weekly":0}}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), source) + + // anthropic 有配置 → 在结果中 + anthro, ok := got["anthropic"] + if !ok { + t.Fatal("expected anthropic to be present") + } + if anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("anthropic daily want 5.0, got %v", anthro.DailyLimitUSD) + } + if anthro.MonthlyLimitUSD == nil || *anthro.MonthlyLimitUSD != 100.0 { + t.Errorf("anthropic monthly want 100.0, got %v", anthro.MonthlyLimitUSD) + } + if anthro.WeeklyLimitUSD != nil { + t.Errorf("anthropic weekly not configured, want nil, got %v", *anthro.WeeklyLimitUSD) + } + + // openai weekly=0 → 在结果中 + oai, ok := got["openai"] + if !ok { + t.Fatal("expected openai to be present") + } + if oai.WeeklyLimitUSD == nil || *oai.WeeklyLimitUSD != 0 { + t.Errorf("openai weekly want 0, got %v", oai.WeeklyLimitUSD) + } + + // gemini / antigravity 无配置 → 不在结果中(override 语义) + if _, ok := got["gemini"]; ok { + t.Error("gemini not configured, should be absent from result") + } + if _, ok := got["antigravity"]; ok { + t.Error("antigravity not configured, should be absent from result") + } +} + +func TestGetAuthSourcePlatformQuotas_AllNegativeOrEmpty_NoEntry(t *testing.T) { + source := "linuxdo" + // 新 JSON 格式:未配置任何平台(空 JSON key)→ 返回空 map + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas(source): `{}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), source) + // 空 map → override 语义,无 openai 条目 + if _, ok := got["openai"]; ok { + t.Error("empty JSON object should result in no openai entry") + } + if len(got) != 0 { + t.Errorf("expected empty result map, got %v", got) + } +} + +// TestSystemPlatformQuotas_WriteReadRoundTrip 验证系统层 platform quota 经 buildSystemSettingsUpdates(写) +// 再由 GetDefaultPlatformQuotas(读)正确往返——覆盖真实 write→read 路径,锁住 4-key 补齐契约。 +func TestSystemPlatformQuotas_WriteReadRoundTrip(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + ctx := context.Background() + + ten := 10.0 + ss := &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten, WeeklyLimitUSD: nil, MonthlyLimitUSD: nil}, + }, + } + if err := svc.UpdateSettings(ctx, ss); err != nil { + t.Fatalf("UpdateSettings: %v", err) + } + + got, err := svc.GetDefaultPlatformQuotas(ctx) + if err != nil { + t.Fatal(err) + } + // 4-key 补齐契约:无论写了几个 platform,读回必须含全部 4 个 + for _, p := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[p]; !ok { + t.Errorf("4-key contract violated: missing platform %q", p) + } + } + // 写入值正确往返 + if v := got["anthropic"].DailyLimitUSD; v == nil || *v != ten { + t.Fatalf("anthropic daily round-trip failed: got %v, want 10", v) + } + // 未写入的平台字段为 nil + if got["openai"].DailyLimitUSD != nil { + t.Errorf("openai daily should be nil (not written), got %v", got["openai"].DailyLimitUSD) + } +} + +// TestSystemPlatformQuotas_EmptyMapClearsAll 验证空 map 的整体替换语义: +// 写入 DefaultPlatformQuotas={} 后,GetDefaultPlatformQuotas 返回 4 个平台、所有字段均为 nil, +// 明确文档化"空 map = 清空全部配额"是有意为之的 whole-replace 语义。 +func TestSystemPlatformQuotas_EmptyMapClearsAll(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + ctx := context.Background() + + // 先写入有值的配置 + ten := 10.0 + if err := svc.UpdateSettings(ctx, &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + }); err != nil { + t.Fatalf("initial write: %v", err) + } + + // 再写入空 map(整体替换语义:清空全部) + if err := svc.UpdateSettings(ctx, &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{}, + }); err != nil { + t.Fatalf("empty map write: %v", err) + } + + got, err := svc.GetDefaultPlatformQuotas(ctx) + if err != nil { + t.Fatal(err) + } + // 4 个 key 仍然存在(补齐契约) + for _, p := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[p]; !ok { + t.Errorf("4-key contract violated after empty write: missing %q", p) + } + } + // 所有字段 nil(全部已清空) + for _, p := range AllowedQuotaPlatforms { + pq := got[p] + if pq == nil { + continue + } + if pq.DailyLimitUSD != nil || pq.WeeklyLimitUSD != nil || pq.MonthlyLimitUSD != nil { + t.Errorf("platform %q should have all-nil limits after empty-map write, got %+v", p, pq) + } + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_PlatformQuotaRoundTrip 验证 round-4 fix: +// PUT /admin/settings 携带的 auth source × platform × window 限额能完整写入并被 GetAuthSourcePlatformQuotas 读回。 +// Round-4 之前 writeProviderDefaultGrantUpdates 完全没写 PQ key,前端配置静默丢失。 +func TestUpdateSettingsWithAuthSourceDefaults_PlatformQuotaRoundTrip(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + systemSettings := &SystemSettings{} + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": { + DailyLimitUSD: floatPtrPQ(5.0), + WeeklyLimitUSD: nil, // 无限额 + MonthlyLimitUSD: floatPtrPQ(100.0), + }, + "openai": { + DailyLimitUSD: floatPtrPQ(0), // 显式禁用 + }, + }, + }, + } + if err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), systemSettings, authDefaults); err != nil { + t.Fatalf("UpdateSettingsWithAuthSourceDefaults: %v", err) + } + got := svc.GetAuthSourcePlatformQuotas(context.Background(), "email") + anthro := got["anthropic"] + if anthro == nil || anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("anthropic daily round-trip failed: %+v", anthro) + } + if anthro != nil && anthro.WeeklyLimitUSD != nil { + t.Errorf("anthropic weekly want nil (无限额), got %v", *anthro.WeeklyLimitUSD) + } + if anthro == nil || anthro.MonthlyLimitUSD == nil || *anthro.MonthlyLimitUSD != 100.0 { + t.Errorf("anthropic monthly round-trip failed: %+v", anthro) + } + oai := got["openai"] + if oai == nil || oai.DailyLimitUSD == nil || *oai.DailyLimitUSD != 0 { + t.Errorf("openai daily=0 (禁用) round-trip failed: %+v", oai) + } + // 其他 source 不应有 quota(authDefaults 只填了 Email) + if linux := svc.GetAuthSourcePlatformQuotas(context.Background(), "linuxdo"); len(linux) != 0 { + t.Errorf("linuxdo should be empty, got %+v", linux) + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_NilPlatformQuotaPreservesExisting 验证 #2 防御: +// 请求未携带某 auth source 的 platform quota(nil)时跳过写入、保留既有配置, +// 而非整体替换为空 map 清空(与系统层 nil 守卫一致)。 +func TestUpdateSettingsWithAuthSourceDefaults_NilPlatformQuotaPreservesExisting(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas("email"): `{"anthropic":{"daily":5,"weekly":null,"monthly":null}}`, + }) + // authDefaults 不携带 Email 的 PlatformQuotas(nil)——应保留既有配置 + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{PlatformQuotas: nil}, + } + if err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), &SystemSettings{}, authDefaults); err != nil { + t.Fatalf("UpdateSettingsWithAuthSourceDefaults: %v", err) + } + anthro := svc.GetAuthSourcePlatformQuotas(context.Background(), "email")["anthropic"] + if anthro == nil || anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("nil PlatformQuotas 应保留既有 anthropic daily=5,got %+v", anthro) + } +} + +// TestGetAuthSourcePlatformQuotas_JSON 验证新 JSON key 读写语义: +// 写入 JSON,断言已配置平台在结果中、未配置平台不在结果中(override 语义)。 +func TestGetAuthSourcePlatformQuotas_JSON(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas("email"): `{"openai":{"daily":null,"weekly":null,"monthly":20}}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), "email") + + // openai monthly = 20 + oai, ok := got["openai"] + if !ok { + t.Fatal("expected openai to be present") + } + if oai.MonthlyLimitUSD == nil || *oai.MonthlyLimitUSD != 20 { + t.Errorf("openai monthly want 20, got %v", oai.MonthlyLimitUSD) + } + if oai.DailyLimitUSD != nil { + t.Errorf("openai daily want nil, got %v", *oai.DailyLimitUSD) + } + if oai.WeeklyLimitUSD != nil { + t.Errorf("openai weekly want nil, got %v", *oai.WeeklyLimitUSD) + } + + // anthropic 未配置 → 不在结果中(override 语义) + if _, ok := got["anthropic"]; ok { + t.Error("anthropic not configured, should be absent from result") + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_NegativeQuotaRejected 验证改动 C: +// auth-source platform quota 含负数时,UpdateSettingsWithAuthSourceDefaults 返回 BadRequest 错误。 +func TestUpdateSettingsWithAuthSourceDefaults_NegativeQuotaRejected(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + neg := -1.0 + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &neg}, + }, + }, + } + err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), &SystemSettings{}, authDefaults) + require.Error(t, err, "expected error for negative quota") + require.Equal(t, "INVALID_DEFAULT_PLATFORM_QUOTA", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index c9bea224..3f961ab2 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -219,6 +219,9 @@ type SystemSettings struct { // 账号限额通知 AccountQuotaNotifyEnabled bool AccountQuotaNotifyEmails []NotifyEmailEntry + + // 系统全局默认平台配额(key = platform,nil/缺省 = 不限制) + DefaultPlatformQuotas map[string]*DefaultPlatformQuotaSetting `json:"default_platform_quotas"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/service/user_platform_quota_port.go b/backend/internal/service/user_platform_quota_port.go new file mode 100644 index 00000000..cb09542a --- /dev/null +++ b/backend/internal/service/user_platform_quota_port.go @@ -0,0 +1,50 @@ +package service + +import ( + "context" + "errors" + "time" +) + +// ErrUserPlatformQuotaNotFound service 层 sentinel:quota 记录不存在。 +// adapter 将 repository.ErrUserPlatformQuotaNotFound 包装为此错误, +// handler 只需引用 service 包,无需直接依赖 repository 包。 +var ErrUserPlatformQuotaNotFound = errors.New("user platform quota not found") + +// UserPlatformQuotaRecord service 层传输结构体(与 repository 层解耦)。 +type UserPlatformQuotaRecord struct { + UserID int64 + Platform string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + // 窗口起始时间(可选,用于未来 reset 校验) + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + +// UserPlatformQuotaRepository 定义 service 层所需的 user × platform quota 数据访问端口。 +// repository 包的 userPlatformQuotaRepository 实现此接口。 +type UserPlatformQuotaRepository interface { + // GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。 + GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) + // BulkInsertInitial 幂等批量插入初始配额记录(ON CONFLICT DO NOTHING)。 + BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error + // IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。 + IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error + // ListByUser 查询用户的所有平台配额记录。 + ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) + // UpsertForUser 全量替换该用户所有平台限额配置(事务内): + // 1. 软删除未在 records 中出现的所有 active 行 + // 2. 对 records 中每条:UPDATE 已存在的(含重新激活软删行);UPDATE 未命中时 INSERT + // 仅改 *_limit_usd + deleted_at + updated_at,保留 *_usage_usd / *_window_start。 + // records 为空时仅执行步骤 1。 + UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error + // ResetExpiredWindow 重置指定窗口("daily"|"weekly"|"monthly")的用量与起始时间。 + // 未命中活跃记录时返回(service-side wrapper of repository.ErrUserPlatformQuotaNotFound)。 + ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error +} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 775dd602..19aec5d3 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -202,7 +202,7 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { out := make([]UserAuthIdentityRecord, len(m.identities)) copy(out, m.identities) @@ -315,6 +315,22 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err return nil } +func (m *mockBillingCache) GetUserPlatformQuotaCache(context.Context, int64, string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (m *mockBillingCache) SetUserPlatformQuotaCache(context.Context, int64, string, *UserPlatformQuotaCacheEntry, time.Duration) error { + return nil +} + +func (m *mockBillingCache) DeleteUserPlatformQuotaCache(context.Context, int64, string) error { + return nil +} + +func (m *mockBillingCache) IncrUserPlatformQuotaUsageCache(context.Context, int64, string, float64, time.Duration) error { + return nil +} + // --- 测试 --- func TestUpdateBalance_Success(t *testing.T) { diff --git a/backend/internal/service/windsurf_google_login_test.go b/backend/internal/service/windsurf_google_login_test.go index de0b6647..1940abf8 100644 --- a/backend/internal/service/windsurf_google_login_test.go +++ b/backend/internal/service/windsurf_google_login_test.go @@ -102,7 +102,7 @@ func (*tokenLoginRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, func (*tokenLoginRepoStub) SetRateLimited(context.Context, int64, time.Time) error { panic("unexpected") } -func (*tokenLoginRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time) error { +func (*tokenLoginRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error { panic("unexpected") } func (*tokenLoginRepoStub) SetOverloaded(context.Context, int64, time.Time) error { diff --git a/backend/internal/service/windsurf_tier_access_service_test.go b/backend/internal/service/windsurf_tier_access_service_test.go index f1e6ee93..b9e4402e 100644 --- a/backend/internal/service/windsurf_tier_access_service_test.go +++ b/backend/internal/service/windsurf_tier_access_service_test.go @@ -97,7 +97,7 @@ func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, func (*tierAccessRepoStub) SetRateLimited(context.Context, int64, time.Time) error { panic("unexpected") } -func (*tierAccessRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time) error { +func (*tierAccessRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error { panic("unexpected") } func (*tierAccessRepoStub) SetOverloaded(context.Context, int64, time.Time) error { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 94eb5d20..0c6c782d 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -420,8 +420,9 @@ func ProvideBillingCacheService( rpmCache UserRPMCache, rateRepo UserGroupRateRepository, cfg *config.Config, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *BillingCacheService { - return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) + return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg, userPlatformQuotaRepo) } // ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation. diff --git a/backend/migrations/142_user_platform_quotas.sql b/backend/migrations/142_user_platform_quotas.sql new file mode 100644 index 00000000..34375edd --- /dev/null +++ b/backend/migrations/142_user_platform_quotas.sql @@ -0,0 +1,36 @@ +-- 用户平台维度配额表。每个 (user_id, platform) 对独立记录日/周/月三层 USD 限额与用量。 +-- nil limit = 不限制(沿用上层默认),0 = 显式禁用,>0 = USD 上限。 +-- 软删除:deleted_at IS NULL 的记录为活跃记录,部分唯一索引保证同用户同平台只有一条活跃配额。 + +CREATE TABLE IF NOT EXISTS user_platform_quotas ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + platform VARCHAR(32) NOT NULL CHECK (platform IN ('anthropic', 'openai', 'gemini', 'antigravity')), + + -- 日 / 周 / 月 USD 上限:NULL = 不限制,0 = 显式禁用,>0 = 上限 + daily_limit_usd DECIMAL(20,10), + weekly_limit_usd DECIMAL(20,10), + monthly_limit_usd DECIMAL(20,10), + + -- 当前窗口已用量 + daily_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + weekly_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + monthly_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + + -- 窗口起点(NULL = 首次尚未初始化) + daily_window_start TIMESTAMPTZ, + weekly_window_start TIMESTAMPTZ, + monthly_window_start TIMESTAMPTZ, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); + +-- 软删除友好唯一索引:同用户同平台只允许一条未删除记录 +CREATE UNIQUE INDEX IF NOT EXISTS userplatformquota_user_id_platform_uq + ON user_platform_quotas (user_id, platform) + WHERE deleted_at IS NULL; + +CREATE INDEX IF NOT EXISTS userplatformquota_user_id + ON user_platform_quotas (user_id); diff --git a/backend/migrations/143_group_models_list_config.sql b/backend/migrations/143_group_models_list_config.sql new file mode 100644 index 00000000..67f27623 --- /dev/null +++ b/backend/migrations/143_group_models_list_config.sql @@ -0,0 +1,5 @@ +-- 分组级自定义 /v1/models 展示列表配置。 +-- 仅用于控制 GET /v1/models 的展示结果,不参与账号白名单、模型映射或网关调度。 + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS models_list_config JSONB NOT NULL DEFAULT '{}'::jsonb; diff --git a/deploy/.env.example b/deploy/.env.example index b8e85dab..22041123 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -254,6 +254,14 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10 # # 默认:false GATEWAY_FORCE_CODEX_CLI=false +# OpenAI/Codex 等待上游响应头超时(秒);0 表示不使用本地响应头超时截断。 +GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT=0 +# OpenAI HTTP 上游默认启用 HTTP/2;如需紧急回滚可设为 false。 +GATEWAY_OPENAI_HTTP2_ENABLED=true +GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1=true +GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD=2 +GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS=60 +GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS=600 # 上游连接池:每主机最大连接数(默认 1024;流式/HTTP1.1 场景可调大,如 2400/4096) GATEWAY_MAX_CONNS_PER_HOST=2048 # 上游连接池:最大空闲连接总数(默认 2560;账号/代理隔离 + 高并发场景可调大) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8e9b0e3b..31b38a19 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -149,6 +149,9 @@ gateway: # Timeout for waiting upstream response headers (seconds) # 等待上游响应头超时时间(秒) response_header_timeout: 600 + # OpenAI/Codex upstream response header timeout (seconds, 0=disabled) + # OpenAI/Codex 等待上游响应头超时时间(秒,0=禁用本地响应头超时) + openai_response_header_timeout: 0 # Max request body size in bytes (default: 256MB) # 请求体最大字节数(默认 256MB) max_body_size: 268435456 @@ -317,6 +320,14 @@ gateway: queue: 0.7 error_rate: 0.8 ttft: 0.5 + # OpenAI HTTP upstream protocol strategy. + # OpenAI HTTP 上游协议策略(默认 HTTP/2;代理明确不兼容时可临时回退 HTTP/1.1)。 + openai_http2: + enabled: true + allow_proxy_fallback_to_http1: true + fallback_error_threshold: 2 + fallback_window_seconds: 60 + fallback_ttl_seconds: 600 # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts @@ -1001,6 +1012,9 @@ billing: # Number of requests to allow in half-open state # 半开状态允许通过的请求数 half_open_requests: 3 + # Cache TTL (seconds) for per-user × per-platform quota records + # 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口 + user_platform_quota_cache_ttl_seconds: 86400 # ============================================================================= # Turnstile Configuration diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index b7a805b5..47e0bcad 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -40,6 +40,13 @@ services: - JWT_SECRET=${JWT_SECRET:-} - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} - TZ=${TZ:-Asia/Shanghai} + # OpenAI HTTP upstream protocol/timeout + - GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT=${GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT:-0} + - GATEWAY_OPENAI_HTTP2_ENABLED=${GATEWAY_OPENAI_HTTP2_ENABLED:-true} + - GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1=${GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1:-true} + - GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD=${GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD:-2} + - GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS:-60} + - GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS:-600} - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 40f97532..47c64909 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -152,6 +152,13 @@ services: # ======================================================================= # Image Generation Stream & Concurrency # ======================================================================= + # OpenAI HTTP upstream protocol/timeout + - GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT=${GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT:-0} + - GATEWAY_OPENAI_HTTP2_ENABLED=${GATEWAY_OPENAI_HTTP2_ENABLED:-true} + - GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1=${GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1:-true} + - GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD=${GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD:-2} + - GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS:-60} + - GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS:-600} - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index 44383dbe..32afb28d 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -98,6 +98,13 @@ services: # ======================================================================= # Image Generation Stream & Concurrency # ======================================================================= + # OpenAI HTTP upstream protocol/timeout + - GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT=${GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT:-0} + - GATEWAY_OPENAI_HTTP2_ENABLED=${GATEWAY_OPENAI_HTTP2_ENABLED:-true} + - GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1=${GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1:-true} + - GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD=${GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD:-2} + - GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS:-60} + - GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS:-600} - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 2cfb93f1..ad91744d 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -113,6 +113,13 @@ services: # ======================================================================= # Image Generation Stream & Concurrency # ======================================================================= + # OpenAI HTTP upstream protocol/timeout + - GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT=${GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT:-0} + - GATEWAY_OPENAI_HTTP2_ENABLED=${GATEWAY_OPENAI_HTTP2_ENABLED:-true} + - GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1=${GATEWAY_OPENAI_HTTP2_ALLOW_PROXY_FALLBACK_TO_HTTP1:-true} + - GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD=${GATEWAY_OPENAI_HTTP2_FALLBACK_ERROR_THRESHOLD:-2} + - GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_WINDOW_SECONDS:-60} + - GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS=${GATEWAY_OPENAI_HTTP2_FALLBACK_TTL_SECONDS:-600} - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts index 10f6247a..4dc7db2e 100644 --- a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -3,9 +3,20 @@ import { describe, expect, it } from "vitest"; import { appendAuthSourceDefaultsToUpdateRequest, buildAuthSourceDefaultsState, + normalizePlatformQuotasMap, + sanitizePlatformQuotasMap, type UpdateSettingsRequest, + type DefaultPlatformQuotasMap, } from "@/api/admin/settings"; +/** 全 null 的 4 平台 map,用于断言归一化默认值 */ +const allNullQuotas: DefaultPlatformQuotasMap = { + anthropic: { daily: null, weekly: null, monthly: null }, + openai: { daily: null, weekly: null, monthly: null }, + gemini: { daily: null, weekly: null, monthly: null }, + antigravity: { daily: null, weekly: null, monthly: null }, +} + describe("admin settings auth source defaults helpers", () => { it("builds auth source defaults state from flat settings fields", () => { const state = buildAuthSourceDefaultsState({ @@ -31,6 +42,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 1, validity_days: 30 }], grant_on_signup: false, grant_on_first_bind: true, + platform_quotas: allNullQuotas, }); expect(state.linuxdo).toEqual({ balance: 6, @@ -38,6 +50,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 2, validity_days: 60 }], grant_on_signup: true, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); expect(state.oidc).toEqual({ balance: 0, @@ -45,6 +58,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); expect(state.wechat).toEqual({ balance: 0, @@ -52,6 +66,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); }); @@ -64,6 +79,23 @@ describe("admin settings auth source defaults helpers", () => { expect(state.wechat.grant_on_signup).toBe(false); }); + it("reads nested platform_quotas from settings into auth source state", () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_platform_quotas: { + anthropic: { daily: 10, weekly: 50, monthly: 200 }, + openai: { daily: null, weekly: null, monthly: null }, + } as DefaultPlatformQuotasMap, + }); + + // anthropic 填写的值应被保留 + expect(state.email.platform_quotas.anthropic).toEqual({ daily: 10, weekly: 50, monthly: 200 }); + // openai 全 null 应被保留 + expect(state.email.platform_quotas.openai).toEqual({ daily: null, weekly: null, monthly: null }); + // 未出现的平台(gemini/antigravity)归一化为 null + expect(state.email.platform_quotas.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(state.email.platform_quotas.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); + it("appends auth source defaults back onto update payload", () => { const payload: UpdateSettingsRequest = { site_name: "Sub2API", @@ -76,6 +108,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 3, validity_days: 7 }], grant_on_signup: true, grant_on_first_bind: false, + platform_quotas: {}, }, linuxdo: { balance: 0, @@ -83,6 +116,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: true, + platform_quotas: {}, }, oidc: { balance: 4, @@ -90,6 +124,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 9, validity_days: 90 }], grant_on_signup: true, grant_on_first_bind: true, + platform_quotas: {}, }, wechat: { balance: 2, @@ -97,6 +132,31 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: {}, + }, + github: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, + }, + google: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, + }, + dingtalk: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, }, }); @@ -126,6 +186,111 @@ describe("admin settings auth source defaults helpers", () => { auth_source_default_wechat_subscriptions: [], auth_source_default_wechat_grant_on_signup: false, auth_source_default_wechat_grant_on_first_bind: false, + // 嵌套 platform_quotas 字段 + auth_source_default_email_platform_quotas: allNullQuotas, + auth_source_default_linuxdo_platform_quotas: allNullQuotas, + auth_source_default_oidc_platform_quotas: allNullQuotas, + auth_source_default_wechat_platform_quotas: allNullQuotas, + auth_source_default_github_platform_quotas: allNullQuotas, + auth_source_default_google_platform_quotas: allNullQuotas, + auth_source_default_dingtalk_platform_quotas: allNullQuotas, }); }); + + it("appends sanitized nested platform_quotas with non-null values in update payload", () => { + const payload: UpdateSettingsRequest = {}; + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: { + anthropic: { daily: 10, weekly: 50, monthly: 200 }, + openai: { daily: 0, weekly: null, monthly: null }, + }, + }, + linuxdo: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + oidc: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + wechat: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + github: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + google: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + dingtalk: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + }); + + const emailQuotas = (payload as Record)["auth_source_default_email_platform_quotas"] as DefaultPlatformQuotasMap; + expect(emailQuotas.anthropic).toEqual({ daily: 10, weekly: 50, monthly: 200 }); + // 0 是合法值(不限额=0 与"不设"不同,保留) + expect(emailQuotas.openai?.daily).toBe(0); + // 缺失平台归一化为全 null + expect(emailQuotas.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(emailQuotas.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); +}); + +describe("normalizePlatformQuotasMap", () => { + it("填充缺失的平台为全 null 三档", () => { + const result = normalizePlatformQuotasMap({ anthropic: { daily: 5, weekly: null, monthly: null } }); + expect(result.anthropic).toEqual({ daily: 5, weekly: null, monthly: null }); + expect(result.openai).toEqual({ daily: null, weekly: null, monthly: null }); + expect(result.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(result.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); + + it("无参数时返回全 4 平台全 null", () => { + const result = normalizePlatformQuotasMap(); + expect(Object.keys(result)).toHaveLength(4); + for (const v of Object.values(result)) { + expect(v).toEqual({ daily: null, weekly: null, monthly: null }); + } + }); + + it("非 number 类型的值归一化为 null", () => { + const result = normalizePlatformQuotasMap({ + anthropic: { daily: "50" as unknown as number, weekly: undefined as unknown as number, monthly: null }, + }); + expect(result.anthropic).toEqual({ daily: null, weekly: null, monthly: null }); + }); +}); + +describe("sanitizePlatformQuotasMap", () => { + it("保留合法的正数和零值", () => { + const result = sanitizePlatformQuotasMap({ + anthropic: { daily: 10.5, weekly: 0, monthly: null }, + }); + expect(result.anthropic?.daily).toBe(10.5); + expect(result.anthropic?.weekly).toBe(0); + expect(result.anthropic?.monthly).toBe(null); + }); + + it("空字符串(v-model.number 空输入)清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + anthropic: { daily: "" as unknown as number, weekly: null, monthly: null }, + }); + expect(result.anthropic?.daily).toBe(null); + }); + + it("负数清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + openai: { daily: -1, weekly: null, monthly: null }, + }); + expect(result.openai?.daily).toBe(null); + }); + + it("NaN/Infinity 清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + gemini: { daily: NaN, weekly: Infinity, monthly: null }, + }); + expect(result.gemini?.daily).toBe(null); + expect(result.gemini?.weekly).toBe(null); + }); + + it("缺失平台填充为全 null", () => { + const result = sanitizePlatformQuotasMap({}); + expect(Object.keys(result)).toHaveLength(4); + for (const v of Object.values(result)) { + expect(v).toEqual({ daily: null, weekly: null, monthly: null }); + } + }); }); diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 92b0abca..6fd23c47 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -204,6 +204,30 @@ export async function refreshCredentials(id: number): Promise { return data } +/** + * Apply OAuth credentials after re-authorization. + * + * Unlike `update()`, this endpoint: + * - never overwrites the whole `extra` JSONB (merges incrementally instead), + * so persistent settings like `base_rpm`, `window_cost_limit`, `max_sessions`, + * `quota_*` and `privacy_mode` are preserved + * - clears the account error and invalidates the token cache server-side + */ +export async function applyOAuthCredentials( + id: number, + payload: { + type: 'oauth' | 'setup-token' + credentials: Record + extra?: Record + } +): Promise { + const { data } = await apiClient.post( + `/admin/accounts/${id}/apply-oauth-credentials`, + payload + ) + return data +} + /** * Get account usage statistics * @param id - Account ID @@ -665,6 +689,7 @@ export const accountsAPI = { toggleStatus, testAccount, refreshCredentials, + applyOAuthCredentials, getStats, clearError, getUsage, diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 6b94b799..b7846efd 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -76,6 +76,23 @@ export async function getById(id: number): Promise { return data } +/** + * Get candidate models for custom /v1/models list. + * id=0 returns platform default models for create flow. + */ +export async function getModelsListCandidates( + id: number, + platform?: GroupPlatform +): Promise { + const { data } = await apiClient.get<{ models: string[] }>( + `/admin/groups/${id}/models-list-candidates`, + { + params: platform ? { platform } : undefined + } + ) + return data.models || [] +} + /** * Create new group * @param groupData - Group data @@ -306,6 +323,7 @@ export const groupsAPI = { getAll, getByPlatform, getById, + getModelsListCandidates, create, update, delete: deleteGroup, diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts index fbba96be..aefd1618 100644 --- a/frontend/src/api/admin/riskControl.ts +++ b/frontend/src/api/admin/riskControl.ts @@ -24,6 +24,7 @@ export interface ContentModerationConfig { all_groups: boolean group_ids: number[] record_non_hits: boolean + thresholds: Record worker_count: number queue_size: number block_status: number @@ -98,6 +99,7 @@ export interface UpdateContentModerationConfig { all_groups?: boolean group_ids?: number[] record_non_hits?: boolean + thresholds?: Record worker_count?: number queue_size?: number block_status?: number @@ -130,6 +132,16 @@ export interface ContentModerationRuntimeStatus { dropped: number processed: number errors: number + pre_block_active: number + pre_block_checked: number + pre_block_allowed: number + pre_block_blocked: number + pre_block_errors: number + pre_block_avg_latency_ms: number + pre_block_api_key_active: number + pre_block_api_key_available_count: number + pre_block_api_key_total_calls: number + pre_block_api_key_loads: ContentModerationAPIKeyLoad[] api_key_statuses: ContentModerationAPIKeyStatus[] flagged_hash_count: number last_cleanup_at?: string @@ -137,6 +149,20 @@ export interface ContentModerationRuntimeStatus { last_cleanup_deleted_non_hit: number } +export interface ContentModerationAPIKeyLoad { + index: number + key_hash: string + masked: string + status: ContentModerationAPIKeyStatusValue + active: number + total: number + success: number + errors: number + avg_latency_ms: number + last_latency_ms: number + last_http_status: number +} + export interface ContentModerationLog { id: number request_id: string diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 7374d8d3..95f7be1d 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -16,6 +16,47 @@ export interface DefaultSubscriptionSetting { validity_days: number; } +// ── 平台限额类型 ────────────────────────────────────────────────── +export type PlatformType = "anthropic" | "openai" | "gemini" | "antigravity" +export type QuotaWindowType = "daily" | "weekly" | "monthly" + +/** 单平台三档限额;null = 不限制,undefined = 未填(等价 null) */ +export interface PlatformQuotaLimits { + daily: number | null + weekly: number | null + monthly: number | null +} + +/** 全平台默认限额 map(key = PlatformType) */ +export type DefaultPlatformQuotasMap = Partial> + +const PLATFORMS: PlatformType[] = ["anthropic", "openai", "gemini", "antigravity"] + +/** 归一化为全 4 平台 × 3 窗口(缺失填 null),供模板非空绑定 */ +export function normalizePlatformQuotasMap(input?: DefaultPlatformQuotasMap | null): DefaultPlatformQuotasMap { + const result: DefaultPlatformQuotasMap = {} + for (const p of PLATFORMS) { + const src = input?.[p] + result[p] = { + daily: typeof src?.daily === "number" ? src.daily : null, + weekly: typeof src?.weekly === "number" ? src.weekly : null, + monthly: typeof src?.monthly === "number" ? src.monthly : null, + } + } + return result +} + +/** 提交前清洗:非有限数/负数/空字符串 → null(保留 0 = 显式禁用),返回全 4 平台嵌套 map */ +export function sanitizePlatformQuotasMap(input?: DefaultPlatformQuotasMap | null): DefaultPlatformQuotasMap { + const clean = (v: unknown): number | null => (typeof v === "number" && Number.isFinite(v) && v >= 0 ? v : null) + const result: DefaultPlatformQuotasMap = {} + for (const p of PLATFORMS) { + const src = input?.[p] + result[p] = { daily: clean(src?.daily), weekly: clean(src?.weekly), monthly: clean(src?.monthly) } + } + return result +} + export type AuthSourceType = | "email" | "linuxdo" @@ -31,6 +72,8 @@ export interface AuthSourceDefaultsValue { subscriptions: DefaultSubscriptionSetting[]; grant_on_signup: boolean; grant_on_first_bind: boolean; + // ★ 新增:平台限额覆盖(key = PlatformType) + platform_quotas: DefaultPlatformQuotasMap; } export type AuthSourceDefaultsState = Record< @@ -193,6 +236,7 @@ export function buildAuthSourceDefaultsState( raw[`auth_source_default_${source}_grant_on_signup`] === true, grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + platform_quotas: normalizePlatformQuotasMap(raw[`auth_source_default_${source}_platform_quotas`] as DefaultPlatformQuotasMap | undefined), }; return acc; }, {} as AuthSourceDefaultsState); @@ -220,6 +264,7 @@ export function appendAuthSourceDefaultsToUpdateRequest( current.grant_on_signup; target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind; + target[`auth_source_default_${source}_platform_quotas`] = sanitizePlatformQuotasMap(current.platform_quotas) } return payload; @@ -370,6 +415,15 @@ export interface SystemSettings { auth_source_default_google_grant_on_signup?: boolean; auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; + // ── 平台限额(嵌套 JSON,系统层 + 7 auth-source 层)──────────────────────────────── + default_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_email_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_linuxdo_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_oidc_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_wechat_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_github_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_google_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_dingtalk_platform_quotas?: DefaultPlatformQuotasMap; // OEM settings site_name: string; site_logo: string; @@ -617,6 +671,15 @@ export interface UpdateSettingsRequest { auth_source_default_google_grant_on_signup?: boolean; auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; + // ── 平台限额(嵌套 JSON,系统层 + 7 auth-source 层)──────────────────────────────── + default_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_email_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_linuxdo_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_oidc_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_wechat_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_github_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_google_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_dingtalk_platform_quotas?: DefaultPlatformQuotasMap; site_name?: string; site_logo?: string; site_subtitle?: string; diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index fabc69bc..bfe5e3ba 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -297,6 +297,78 @@ export async function bindUserAuthIdentity( return data } +/** + * Platform quota types + */ +export type PlatformQuotaPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' +export type PlatformQuotaWindow = 'daily' | 'weekly' | 'monthly' + +export interface PlatformQuotaItem { + platform: PlatformQuotaPlatform + daily_limit_usd: number | null + weekly_limit_usd: number | null + monthly_limit_usd: number | null + daily_usage_usd: number + weekly_usage_usd: number + monthly_usage_usd: number + daily_window_start?: string | null + weekly_window_start?: string | null + monthly_window_start?: string | null + daily_window_resets_at?: string | null + weekly_window_resets_at?: string | null + monthly_window_resets_at?: string | null +} + +export interface PlatformQuotaUpdateItem { + platform: PlatformQuotaPlatform + daily_limit_usd: number | null + weekly_limit_usd: number | null + monthly_limit_usd: number | null +} + +export interface PlatformQuotasResponse { + platform_quotas: PlatformQuotaItem[] +} + +/** + * Get user's platform quotas + */ +export async function getPlatformQuotas(id: number): Promise { + const { data } = await apiClient.get( + `/admin/users/${id}/platform-quotas` + ) + return data +} + +/** + * Replace user's platform quotas (全量替换) + */ +export async function updatePlatformQuotas( + id: number, + quotas: PlatformQuotaUpdateItem[] +): Promise { + const { data } = await apiClient.put( + `/admin/users/${id}/platform-quotas`, + { quotas } + ) + return data +} + +/** + * Reset a single (platform, window) usage immediately + */ +export async function resetPlatformQuotaWindow( + id: number, + platform: PlatformQuotaPlatform, + window: PlatformQuotaWindow +): Promise { + const { data } = await apiClient.post( + `/admin/users/${id}/platform-quotas/reset`, + { platform, window } + ) + return data +} + export const usersAPI = { list, getById, @@ -310,7 +382,10 @@ export const usersAPI = { getUserUsageStats, getUserBalanceHistory, replaceGroup, - bindUserAuthIdentity + bindUserAuthIdentity, + getPlatformQuotas, + updatePlatformQuotas, + resetPlatformQuotaWindow, } export default usersAPI diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index da7a91eb..88768530 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -15,7 +15,8 @@ import type { NotifyEmailEntry, UserAuthProvider, UserAffiliateDetail, - AffiliateTransferResponse + AffiliateTransferResponse, + PlatformQuotasResponse, } from '@/types' /** @@ -185,6 +186,14 @@ export async function transferAffiliateQuota(): Promise { + const { data } = await apiClient.get('/user/platform-quotas') + return data +} + export const userAPI = { getProfile, updateProfile, @@ -199,7 +208,8 @@ export const userAPI = { buildOAuthBindingStartURL, startOAuthBinding, getAffiliateDetail, - transferAffiliateQuota + transferAffiliateQuota, + getMyPlatformQuotas, } export default userAPI diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index d60b5a04..235d3119 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1401,6 +1401,18 @@ }}

+
+ + +

+ {{ t('admin.accounts.poolModeRetryStatusCodesHint', { default: DEFAULT_POOL_MODE_RETRY_STATUS_CODES.join(', ') }) }} +

+
@@ -1746,6 +1758,18 @@ }}

+
+ + +

+ {{ t('admin.accounts.poolModeRetryStatusCodesHint', { default: DEFAULT_POOL_MODE_RETRY_STATUS_CODES.join(', ') }) }} +

+
@@ -3413,8 +3437,27 @@ const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) const DEFAULT_POOL_MODE_RETRY_COUNT = 3 const MAX_POOL_MODE_RETRY_COUNT = 10 +const DEFAULT_POOL_MODE_RETRY_STATUS_CODES = [401, 403, 429] const poolModeEnabled = ref(false) const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) +const poolModeRetryStatusCodesInput = ref('') + +function parsePoolModeRetryStatusCodes(input: string): number[] { + if (!input || !input.trim()) return [] + const seen = new Set() + const out: number[] = [] + for (const token of input.split(/[,\s]+/)) { + const trimmed = token.trim() + if (!trimmed) continue + const n = Number(trimmed) + if (!Number.isFinite(n) || !Number.isInteger(n)) continue + if (n < 100 || n > 599) continue + if (seen.has(n)) continue + seen.add(n) + out.push(n) + } + return out.sort((a, b) => a - b) +} const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -4216,6 +4259,7 @@ const resetForm = () => { }) poolModeEnabled.value = false poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT + poolModeRetryStatusCodesInput.value = '' customErrorCodesEnabled.value = false selectedErrorCodes.value = [] customErrorCodeInput.value = null @@ -4502,6 +4546,10 @@ const handleSubmit = async () => { if (poolModeEnabled.value) { credentials.pool_mode = true credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + const parsedRetryStatusCodes = parsePoolModeRetryStatusCodes(poolModeRetryStatusCodesInput.value) + if (parsedRetryStatusCodes.length > 0) { + credentials.pool_mode_retry_status_codes = parsedRetryStatusCodes + } } applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') @@ -4651,6 +4699,10 @@ const handleSubmit = async () => { if (poolModeEnabled.value) { credentials.pool_mode = true credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + const parsedRetryStatusCodes = parsePoolModeRetryStatusCodes(poolModeRetryStatusCodesInput.value) + if (parsedRetryStatusCodes.length > 0) { + credentials.pool_mode_retry_status_codes = parsedRetryStatusCodes + } } // Add custom error codes if enabled diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 070887fe..f44b5d38 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -305,6 +305,18 @@ }}

+
+ + +

+ {{ t('admin.accounts.poolModeRetryStatusCodesHint', { default: DEFAULT_POOL_MODE_RETRY_STATUS_CODES.join(', ') }) }} +

+
@@ -973,6 +985,18 @@ }}

+
+ + +

+ {{ t('admin.accounts.poolModeRetryStatusCodesHint', { default: DEFAULT_POOL_MODE_RETRY_STATUS_CODES.join(', ') }) }} +

+
@@ -2317,8 +2341,42 @@ const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) const DEFAULT_POOL_MODE_RETRY_COUNT = 3 const MAX_POOL_MODE_RETRY_COUNT = 10 +const DEFAULT_POOL_MODE_RETRY_STATUS_CODES = [401, 403, 429] const poolModeEnabled = ref(false) const poolModeRetryCount = ref(DEFAULT_POOL_MODE_RETRY_COUNT) +const poolModeRetryStatusCodesInput = ref('') + +function parsePoolModeRetryStatusCodes(input: string): number[] { + if (!input || !input.trim()) return [] + const seen = new Set() + const out: number[] = [] + for (const token of input.split(/[,\s]+/)) { + const trimmed = token.trim() + if (!trimmed) continue + const n = Number(trimmed) + if (!Number.isFinite(n) || !Number.isInteger(n)) continue + if (n < 100 || n > 599) continue + if (seen.has(n)) continue + seen.add(n) + out.push(n) + } + return out.sort((a, b) => a - b) +} + +function formatPoolModeRetryStatusCodes(value: unknown): string { + if (!Array.isArray(value)) return '' + const out: number[] = [] + const seen = new Set() + for (const v of value) { + const n = typeof v === 'string' ? Number(v.trim()) : Number(v) + if (!Number.isFinite(n) || !Number.isInteger(n)) continue + if (n < 100 || n > 599) continue + if (seen.has(n)) continue + seen.add(n) + out.push(n) + } + return out.sort((a, b) => a - b).join(', ') +} const customErrorCodesEnabled = ref(false) const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) @@ -2807,6 +2865,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { poolModeRetryCount.value = normalizePoolModeRetryCount( Number(credentials.pool_mode_retry_count ?? DEFAULT_POOL_MODE_RETRY_COUNT) ) + poolModeRetryStatusCodesInput.value = formatPoolModeRetryStatusCodes(credentials.pool_mode_retry_status_codes) // Load custom error codes customErrorCodesEnabled.value = credentials.custom_error_codes_enabled === true @@ -2834,6 +2893,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { poolModeEnabled.value = bedrockCreds.pool_mode === true const retryCount = bedrockCreds.pool_mode_retry_count poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT + poolModeRetryStatusCodesInput.value = formatPoolModeRetryStatusCodes(bedrockCreds.pool_mode_retry_status_codes) // Load quota limits for bedrock const bedrockExtra = (newAccount.extra as Record) || {} @@ -2876,6 +2936,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { } poolModeEnabled.value = false poolModeRetryCount.value = DEFAULT_POOL_MODE_RETRY_COUNT + poolModeRetryStatusCodesInput.value = '' customErrorCodesEnabled.value = false selectedErrorCodes.value = [] } @@ -3427,9 +3488,16 @@ const handleSubmit = async () => { if (poolModeEnabled.value) { newCredentials.pool_mode = true newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + const parsedRetryStatusCodes = parsePoolModeRetryStatusCodes(poolModeRetryStatusCodesInput.value) + if (parsedRetryStatusCodes.length > 0) { + newCredentials.pool_mode_retry_status_codes = parsedRetryStatusCodes + } else { + delete newCredentials.pool_mode_retry_status_codes + } } else { delete newCredentials.pool_mode delete newCredentials.pool_mode_retry_count + delete newCredentials.pool_mode_retry_status_codes } // Add custom error codes if enabled @@ -3545,9 +3613,16 @@ const handleSubmit = async () => { if (poolModeEnabled.value) { newCredentials.pool_mode = true newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value) + const parsedRetryStatusCodes = parsePoolModeRetryStatusCodes(poolModeRetryStatusCodesInput.value) + if (parsedRetryStatusCodes.length > 0) { + newCredentials.pool_mode_retry_status_codes = parsedRetryStatusCodes + } else { + delete newCredentials.pool_mode_retry_status_codes + } } else { delete newCredentials.pool_mode delete newCredentials.pool_mode_retry_count + delete newCredentials.pool_mode_retry_status_codes } // Model mapping diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index 637d6011..b7178541 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -371,16 +371,12 @@ const handleExchangeCode = async () => { const extra = oauthClient.buildExtraInfo(tokenInfo) try { - // Update account with new credentials - await adminAPI.accounts.update(props.account.id, { - type: 'oauth', // OpenAI OAuth is always 'oauth' type + const updatedAccount = await adminAPI.accounts.applyOAuthCredentials(props.account.id, { + type: 'oauth', credentials, extra }) - // Clear error status after successful re-authorization - const updatedAccount = await adminAPI.accounts.clearError(props.account.id) - appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) emit('reauthorized', updatedAccount) handleClose() @@ -476,16 +472,12 @@ const handleExchangeCode = async () => { const extra = claudeOAuth.buildExtraInfo(tokenInfo) - // Update account with new credentials and type - await adminAPI.accounts.update(props.account.id, { - type: addMethod.value, // Update type based on selected method - credentials: tokenInfo, + const updatedAccount = await adminAPI.accounts.applyOAuthCredentials(props.account.id, { + type: addMethod.value as 'oauth' | 'setup-token', + credentials: tokenInfo as unknown as Record, extra }) - // Clear error status after successful re-authorization - const updatedAccount = await adminAPI.accounts.clearError(props.account.id) - appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) emit('reauthorized', updatedAccount) handleClose() @@ -519,16 +511,12 @@ const handleCookieAuth = async (sessionKey: string) => { const extra = claudeOAuth.buildExtraInfo(tokenInfo) - // Update account with new credentials and type - await adminAPI.accounts.update(props.account.id, { - type: addMethod.value, // Update type based on selected method - credentials: tokenInfo, + const updatedAccount = await adminAPI.accounts.applyOAuthCredentials(props.account.id, { + type: addMethod.value as 'oauth' | 'setup-token', + credentials: tokenInfo as unknown as Record, extra }) - // Clear error status after successful re-authorization - const updatedAccount = await adminAPI.accounts.clearError(props.account.id) - appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) emit('reauthorized', updatedAccount) handleClose() diff --git a/frontend/src/components/admin/user/UserPlatformQuotaModal.vue b/frontend/src/components/admin/user/UserPlatformQuotaModal.vue new file mode 100644 index 00000000..76d23b0c --- /dev/null +++ b/frontend/src/components/admin/user/UserPlatformQuotaModal.vue @@ -0,0 +1,283 @@ + + + diff --git a/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts b/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts new file mode 100644 index 00000000..06248133 --- /dev/null +++ b/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts @@ -0,0 +1,245 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { createPinia, setActivePinia } from 'pinia' + +const apiMocks = vi.hoisted(() => ({ + getPlatformQuotas: vi.fn(), + updatePlatformQuotas: vi.fn(), + resetPlatformQuotaWindow: vi.fn(), +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + users: { + getPlatformQuotas: apiMocks.getPlatformQuotas, + updatePlatformQuotas: apiMocks.updatePlatformQuotas, + resetPlatformQuotaWindow: apiMocks.resetPlatformQuotaWindow, + }, + }, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError: vi.fn(), + showSuccess: vi.fn(), + }), +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (params) { + return key.replace(/\{(\w+)\}/g, (_, k) => params[k] ?? '') + } + return key + }, + }), + } +}) + +vi.mock('@/components/common/BaseDialog.vue', () => ({ + default: { + name: 'BaseDialog', + props: ['show', 'title', 'width'], + template: '
', + }, +})) + +import UserPlatformQuotaModal from '../UserPlatformQuotaModal.vue' +import type { UserSubscription } from '@/types' + +function makeUser(overrides: { subscriptions?: UserSubscription[] } = {}) { + return { id: 99, email: 'u@example.com', ...overrides } as any +} + +/** 挂载并触发 show:false → true,确保 watch 被激活 */ +async function mountAndOpen(extraProps: Record = {}) { + const w = mount(UserPlatformQuotaModal, { + props: { show: false, user: makeUser(), ...extraProps }, + }) + await w.setProps({ show: true }) + await flushPromises() + return w +} + +beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + apiMocks.getPlatformQuotas.mockResolvedValue({ platform_quotas: [] }) + apiMocks.updatePlatformQuotas.mockResolvedValue({ platform_quotas: [] }) + apiMocks.resetPlatformQuotaWindow.mockResolvedValue({ platform_quotas: [] }) +}) + +describe('UserPlatformQuotaModal', () => { + it('挂载并 show=true 时调用 getPlatformQuotas', async () => { + await mountAndOpen() + expect(apiMocks.getPlatformQuotas).toHaveBeenCalledWith(99) + }) + + it('空数据渲染 4 个 platform 行', async () => { + const w = await mountAndOpen() + const html = w.html() + expect(html).toContain('anthropic') + expect(html).toContain('openai') + expect(html).toContain('gemini') + expect(html).toContain('antigravity') + }) + + it('已有数据正确填充 limit input', async () => { + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: null, monthly_limit_usd: null, + daily_usage_usd: 3.2, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const inputs = w.findAll('input[type=number]') + // 4 platforms × 3 windows = 12 inputs + expect(inputs.length).toBe(12) + // 第一个 input 是 anthropic.daily = 10 + expect((inputs[0].element as HTMLInputElement).value).toBe('10') + }) + + it('保存提交完整 4 platform payload', async () => { + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'openai', daily_limit_usd: null, weekly_limit_usd: 20, monthly_limit_usd: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + // 找到「保存」按钮(包含中文「保存」字样的按钮) + const buttons = w.findAll('button') + const saveBtn = buttons.find((b) => b.text() === 'admin.users.platformQuota.save') + expect(saveBtn).toBeTruthy() + await saveBtn!.trigger('click') + await flushPromises() + expect(apiMocks.updatePlatformQuotas).toHaveBeenCalledTimes(1) + const [uid, payload] = apiMocks.updatePlatformQuotas.mock.calls[0] + expect(uid).toBe(99) + expect(payload).toHaveLength(4) // 4 platforms always submitted + const openai = payload.find((p: any) => p.platform === 'openai') + expect(openai.weekly_limit_usd).toBe(20) + }) + + it('全部清空把所有 limit 置 null(确认通过)', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true) + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: 50, monthly_limit_usd: 100, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const buttons = w.findAll('button') + const clearBtn = buttons.find((b) => b.text() === 'admin.users.platformQuota.clearAll') + expect(clearBtn).toBeTruthy() + await clearBtn!.trigger('click') + await flushPromises() + expect(confirmSpy).toHaveBeenCalledTimes(1) + const inputs = w.findAll('input[type=number]') + for (const inp of inputs) { + expect((inp.element as HTMLInputElement).value).toBe('') + } + confirmSpy.mockRestore() + }) + + it('全部清空 confirm 取消则保持原值', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false) + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: 50, monthly_limit_usd: 100, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const clearBtn = w.findAll('button').find((b) => b.text() === 'admin.users.platformQuota.clearAll') + await clearBtn!.trigger('click') + await flushPromises() + expect(confirmSpy).toHaveBeenCalledTimes(1) + // anthropic daily 应保持 10(未被清空) + const inputs = w.findAll('input[type=number]') + const dailyVal = (inputs[0].element as HTMLInputElement).value + expect(dailyVal).toBe('10') + confirmSpy.mockRestore() + }) + + it('重置按钮 confirm 取消则不调用 API', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false) + const w = await mountAndOpen() + const resetBtns = w.findAll('button').filter((b) => b.text() === '↻') + expect(resetBtns.length).toBeGreaterThan(0) + await resetBtns[0].trigger('click') + await flushPromises() + expect(apiMocks.resetPlatformQuotaWindow).not.toHaveBeenCalled() + confirmSpy.mockRestore() + }) + + it('重置按钮 confirm 确认则调用 API', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true) + const w = await mountAndOpen() + const resetBtns = w.findAll('button').filter((b) => b.text() === '↻') + await resetBtns[0].trigger('click') // 第一个是 anthropic.daily + await flushPromises() + expect(apiMocks.resetPlatformQuotaWindow).toHaveBeenCalledWith(99, 'anthropic', 'daily') + confirmSpy.mockRestore() + }) + + describe('subscription warning banner', () => { + it('displays subscription warning when user has active subscription', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ + subscriptions: [ + { + id: 1, user_id: 99, group_id: 1, status: 'active', + starts_at: '2026-01-01T00:00:00Z', expires_at: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + daily_window_start: null, weekly_window_start: null, monthly_window_start: null, + created_at: '2026-01-01T00:00:00Z', updated_at: '2026-01-01T00:00:00Z', + } as UserSubscription, + ], + }), + }, + }) + await flushPromises() + expect(w.html()).toContain('admin.users.platformQuota.subscriptionWarning') + }) + + it('hides subscription warning when user has only expired subscriptions', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ + subscriptions: [ + { + id: 2, user_id: 99, group_id: 1, status: 'expired', + starts_at: '2025-01-01T00:00:00Z', expires_at: '2025-12-31T00:00:00Z', + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + daily_window_start: null, weekly_window_start: null, monthly_window_start: null, + created_at: '2025-01-01T00:00:00Z', updated_at: '2025-12-31T00:00:00Z', + } as UserSubscription, + ], + }), + }, + }) + await flushPromises() + expect(w.html()).not.toContain('admin.users.platformQuota.subscriptionWarning') + }) + + it('hides subscription warning when subscriptions is empty array', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ subscriptions: [] }), + }, + }) + await flushPromises() + expect(w.html()).not.toContain('admin.users.platformQuota.subscriptionWarning') + }) + }) +}) diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index e3729cc8..1269ae39 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -547,20 +547,21 @@ function generateOpenAIFiles(baseUrl: string, apiKey: string): FileConfig[] { // config.toml content const configContent = `model_provider = "OpenAI" -model = "gpt-5.4" -review_model = "gpt-5.4" +model = "gpt-5.5" +review_model = "gpt-5.5" model_reasoning_effort = "xhigh" disable_response_storage = true network_access = "enabled" windows_wsl_setup_acknowledged = true -model_context_window = 1000000 -model_auto_compact_token_limit = 900000 [model_providers.OpenAI] name = "OpenAI" base_url = "${baseUrl}" wire_api = "responses" -requires_openai_auth = true` +requires_openai_auth = true + +[features] +goals = true` // auth.json content const authContent = `{ @@ -586,14 +587,12 @@ function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] { // config.toml content with WebSocket v2 const configContent = `model_provider = "OpenAI" -model = "gpt-5.4" -review_model = "gpt-5.4" +model = "gpt-5.5" +review_model = "gpt-5.5" model_reasoning_effort = "xhigh" disable_response_storage = true network_access = "enabled" windows_wsl_setup_acknowledged = true -model_context_window = 1000000 -model_auto_compact_token_limit = 900000 [model_providers.OpenAI] name = "OpenAI" @@ -603,7 +602,8 @@ supports_websockets = true requires_openai_auth = true [features] -responses_websockets_v2 = true` +responses_websockets_v2 = true +goals = true` // auth.json content const authContent = `{ diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts index f7db586a..b3fdeb93 100644 --- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts +++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts @@ -17,6 +17,78 @@ vi.mock('@/composables/useClipboard', () => ({ import UseKeyModal from '../UseKeyModal.vue' describe('UseKeyModal', () => { + it('renders GPT-5.5 and goals feature in OpenAI Codex config', () => { + const wrapper = mount(UseKeyModal, { + props: { + show: true, + apiKey: 'sk-test', + baseUrl: 'https://example.com/v1', + platform: 'openai' + }, + global: { + stubs: { + BaseDialog: { + template: '
' + }, + Icon: { + template: '' + } + } + } + }) + + const codeBlocks = wrapper.findAll('pre code').map((code) => code.text()) + const configToml = codeBlocks.find((content) => content.includes('model_provider = "OpenAI"')) + + expect(configToml).toBeDefined() + expect(configToml).toContain('model = "gpt-5.5"') + expect(configToml).toContain('review_model = "gpt-5.5"') + expect(configToml).not.toContain('model = "gpt-5.4"') + expect(configToml).not.toContain('model_context_window') + expect(configToml).not.toContain('model_auto_compact_token_limit') + expect(configToml).toContain('[features]\ngoals = true') + }) + + it('renders GPT-5.5 and goals feature in OpenAI Codex WebSocket config', async () => { + const wrapper = mount(UseKeyModal, { + props: { + show: true, + apiKey: 'sk-test', + baseUrl: 'https://example.com/v1', + platform: 'openai' + }, + global: { + stubs: { + BaseDialog: { + template: '
' + }, + Icon: { + template: '' + } + } + } + }) + + const wsTab = wrapper.findAll('button').find((button) => + button.text().includes('keys.useKeyModal.cliTabs.codexCliWs') + ) + + expect(wsTab).toBeDefined() + await wsTab!.trigger('click') + await nextTick() + + const codeBlocks = wrapper.findAll('pre code').map((code) => code.text()) + const configToml = codeBlocks.find((content) => content.includes('supports_websockets = true')) + + expect(configToml).toBeDefined() + expect(configToml).toContain('model = "gpt-5.5"') + expect(configToml).toContain('review_model = "gpt-5.5"') + expect(configToml).not.toContain('model = "gpt-5.4"') + expect(configToml).not.toContain('model_context_window') + expect(configToml).not.toContain('model_auto_compact_token_limit') + expect(configToml).toContain('[features]\nresponses_websockets_v2 = true\ngoals = true') + }) + it('renders GPT-5.4 mini entry in OpenCode config', async () => { const wrapper = mount(UseKeyModal, { props: { diff --git a/frontend/src/components/user/UserPlatformQuotaCell.vue b/frontend/src/components/user/UserPlatformQuotaCell.vue new file mode 100644 index 00000000..376b61c8 --- /dev/null +++ b/frontend/src/components/user/UserPlatformQuotaCell.vue @@ -0,0 +1,61 @@ + + + diff --git a/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts b/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts new file mode 100644 index 00000000..533d1a32 --- /dev/null +++ b/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts @@ -0,0 +1,74 @@ +import { describe, it, expect, vi } from 'vitest' +import { mount } from '@vue/test-utils' + +// t() 回显 key,便于断言文案键 +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ t: (key: string) => key }), + } +}) + +import UserPlatformQuotaCell from '../UserPlatformQuotaCell.vue' +import type { PlatformQuotaItem } from '@/api/admin/users' + +function item(over: Partial & { platform: PlatformQuotaItem['platform'] }): PlatformQuotaItem { + return { + daily_limit_usd: null, weekly_limit_usd: null, monthly_limit_usd: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + ...over, + } as PlatformQuotaItem +} + +describe('UserPlatformQuotaCell', () => { + it('quotas 为 undefined 时渲染加载占位 …', () => { + const w = mount(UserPlatformQuotaCell, { props: { quotas: undefined } }) + expect(w.text()).toContain('…') + expect(w.html()).not.toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('空数组渲染「未配置」', () => { + const w = mount(UserPlatformQuotaCell, { props: { quotas: [] } }) + expect(w.html()).toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('平台有记录但全部 limit 为 null 时视为未配置', () => { + const w = mount(UserPlatformQuotaCell, { + props: { quotas: [item({ platform: 'openai', daily_usage_usd: 5 })] }, + }) + expect(w.html()).toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('已配置平台渲染 已用/限额,null 档显示 —,金额去尾零', () => { + const w = mount(UserPlatformQuotaCell, { + props: { + quotas: [ + item({ platform: 'anthropic', daily_limit_usd: 100, daily_usage_usd: 30, + weekly_limit_usd: null, weekly_usage_usd: 0, + monthly_limit_usd: 2000, monthly_usage_usd: 90.5 }), + ], + }, + }) + const html = w.html() + expect(html).toContain('anthropic') + expect(html).toContain('30/100') + expect(html).toContain('0/—') + expect(html).toContain('90.5/2000') + }) + + it('多平台按 anthropic→openai→gemini→antigravity 顺序,仅展示有限额的', () => { + const w = mount(UserPlatformQuotaCell, { + props: { + quotas: [ + item({ platform: 'gemini', monthly_limit_usd: 50 }), + item({ platform: 'anthropic', daily_limit_usd: 10 }), + item({ platform: 'openai', daily_usage_usd: 9 }), + ], + }, + }) + const text = w.text() + expect(text.indexOf('anthropic')).toBeLessThan(text.indexOf('gemini')) + expect(text).not.toContain('openai') + }) +}) diff --git a/frontend/src/components/user/dashboard/UserDashboardStats.vue b/frontend/src/components/user/dashboard/UserDashboardStats.vue index 97d2da3d..e7f3449d 100644 --- a/frontend/src/components/user/dashboard/UserDashboardStats.vue +++ b/frontend/src/components/user/dashboard/UserDashboardStats.vue @@ -177,6 +177,46 @@ + + +
+

+ {{ t('dashboard.platformQuota.title') }} +

+ +
@@ -187,11 +227,23 @@ import { computed } from 'vue' import { useI18n } from 'vue-i18n' import Icon from '@/components/icons/Icon.vue' import type { UserDashboardStats as UserStatsType } from '@/api/usage' +import type { PlatformQuotaItem } from '@/types' + +interface FusedPlatformCard { + platform: string + total_actual_cost: number + today_actual_cost: number + total_requests: number + total_tokens: number + isOther?: boolean + quota?: PlatformQuotaItem +} const props = defineProps<{ stats: UserStatsType balance: number isSimple: boolean + platformQuotas?: PlatformQuotaItem[] | null }>() const { t } = useI18n() @@ -213,16 +265,45 @@ const sortedPlatforms = computed(() => { // (group 与 account 都缺 platform)。这里把差值作为"其他"卡片显式展示, // 避免 Row 1 总值与 Row 3 平台拆分加总对不上、用户困惑。 const OTHER_THRESHOLD = 0.0001 -const platformCards = computed(() => { - const cards: Array<{ - platform: string - total_actual_cost: number - today_actual_cost: number - total_requests: number - total_tokens: number - isOther?: boolean - }> = sortedPlatforms.value.map((p) => ({ ...p })) +const platformCards = computed(() => { + // 建立 by_platform Map + const byPlat = new Map() + for (const item of props.stats?.by_platform ?? []) byPlat.set(item.platform, item) + // 建立 quota Map + const byQuota = new Map() + for (const q of props.platformQuotas ?? []) byQuota.set(q.platform, q) + + // union 平台集合。后端 by_platform / quota 接口均不会返回 platform='__other__', + // 无需显式排除;__other__ 由下方差值补差逻辑单独追加。 + const platforms = new Set([...byPlat.keys(), ...byQuota.keys()]) + + const PLATFORM_ORDER = ['anthropic', 'openai', 'gemini', 'antigravity'] + const cards: FusedPlatformCard[] = [] + + for (const p of platforms) { + const stat = byPlat.get(p) + cards.push({ + platform: p, + total_actual_cost: stat?.total_actual_cost ?? 0, + today_actual_cost: stat?.today_actual_cost ?? 0, + total_requests: stat?.total_requests ?? 0, + total_tokens: stat?.total_tokens ?? 0, + quota: byQuota.get(p), + }) + } + + // 排序:按 PLATFORM_ORDER,未知平台按名称排序 + cards.sort((a, b) => { + const ai = PLATFORM_ORDER.indexOf(a.platform) + const bi = PLATFORM_ORDER.indexOf(b.platform) + if (ai === -1 && bi === -1) return a.platform.localeCompare(b.platform) + if (ai === -1) return 1 + if (bi === -1) return -1 + return ai - bi + }) + + // __other__ 补差逻辑:只对 by_platform 有 usage 数据的总和计算 const total = props.stats?.total_actual_cost ?? 0 const today = props.stats?.today_actual_cost ?? 0 const sumTotal = cards.reduce((s, c) => s + c.total_actual_cost, 0) @@ -237,12 +318,62 @@ const platformCards = computed(() => { today_actual_cost: diffToday, total_requests: 0, total_tokens: 0, - isOther: true + isOther: true, }) } + return cards }) +// Quota helpers + +type QuotaWindow = 'daily' | 'weekly' | 'monthly' +type QuotaField = `${QuotaWindow}_limit_usd` | `${QuotaWindow}_usage_usd` | `${QuotaWindow}_window_resets_at` + +function quotaVal(q: PlatformQuotaItem | undefined, key: QuotaField): PlatformQuotaItem[QuotaField] { + return q?.[key] +} + +function hasAnyLimit(q: PlatformQuotaItem | undefined): boolean { + if (!q) return false + return q.daily_limit_usd != null || q.weekly_limit_usd != null || q.monthly_limit_usd != null +} + +function calcPercent(usage: number, limit: number): number { + if (!limit || limit <= 0) return 0 + return Math.min(100, Math.max(0, Math.round((usage / limit) * 100))) +} + +function quotaBarClass(p: number): string { + if (p >= 95) return 'bg-red-500' + if (p >= 75) return 'bg-amber-500' + return 'bg-green-500' +} + +// 与 formatBalance 一致使用 Intl.NumberFormat 做半偶舍入,避免 toFixed 在不同 JS 引擎 +// 下偶发截断而非四舍五入(与后端展示精度不一致)。 +const usdFormatter = new Intl.NumberFormat('en-US', { + minimumFractionDigits: 2, + maximumFractionDigits: 2, +}) +function formatUsd(n: number): string { + if (!Number.isFinite(n)) return '0.00' + return usdFormatter.format(n) +} + +function formatResetTime(iso: string | null | undefined): string { + if (!iso) return '' + const d = new Date(iso) + if (Number.isNaN(d.getTime())) return iso + return d.toLocaleString(undefined, { + month: 'numeric', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + hour12: false, + }) +} + const formatBalance = (b: number) => new Intl.NumberFormat('en-US', { minimumFractionDigits: 2, diff --git a/frontend/src/i18n/__tests__/riskControlLocales.spec.ts b/frontend/src/i18n/__tests__/riskControlLocales.spec.ts new file mode 100644 index 00000000..eab94fe6 --- /dev/null +++ b/frontend/src/i18n/__tests__/riskControlLocales.spec.ts @@ -0,0 +1,24 @@ +import { describe, expect, it } from 'vitest' + +import en from '../locales/en' +import zh from '../locales/zh' + +describe('risk control locale copy', () => { + it('describes worker runtime as audit and pre-block record processing', () => { + expect(zh.admin.riskControl.workerStatusHint).toContain('前置拦截记录任务') + expect(zh.admin.riskControl.workerStatusHint).not.toContain('异步观察任务') + expect(en.admin.riskControl.workerStatusHint).toContain('pre-block record tasks') + expect(en.admin.riskControl.workerStatusHint).not.toContain('observation tasks') + }) + + it('keeps pre-block audit key summary aware of async worker load', () => { + expect(zh.admin.riskControl.preBlockAPIKeyLoadSummary).toContain('worker:{workerActive} / {workerTotal}') + expect(en.admin.riskControl.preBlockAPIKeyLoadSummary).toContain('worker: {workerActive} / {workerTotal}') + }) + + it('does not describe pre-block audit key polling as bypassing the worker pool', () => { + expect(zh.admin.riskControl.preBlockAPIKeyLoadHint).toBe('同步前置拦截直接轮询可用审核 Key。') + expect(zh.admin.riskControl.preBlockAPIKeyLoadHint).not.toContain('Worker 池') + expect(en.admin.riskControl.preBlockAPIKeyLoadHint).not.toContain('worker pool') + }) +}) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 7a8bb607..abd0e6e7 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -635,6 +635,15 @@ export default { platformBreakdownEmpty: 'No platform usage yet', platformCount: '{count} platforms', platformOther: 'Other', + platformQuota: { + title: 'Quota Usage', + daily: 'Daily', + weekly: 'Weekly', + monthly: 'Monthly (30-day rolling)', + resetsAt: 'Resets {time}', + noLimit: 'unlimited', + disabled: 'Disabled', + }, tokenUsageTrend: 'Token Usage Trend', noDataAvailable: 'No data available', model: 'Model', @@ -1794,6 +1803,7 @@ export default { groups: 'Groups', subscriptions: 'Subscriptions', balance: 'Balance', + balancePlatformQuota: 'Balance (Platform Quota)', usage: 'Usage', usageAnthropic: 'Usage (Claude)', usageOpenAI: 'Usage (OpenAI)', @@ -1986,6 +1996,41 @@ export default { failedToReorder: 'Failed to update order', keyExists: 'Attribute key already exists', dragToReorder: 'Drag to reorder' + }, + platformQuota: { + menuItem: 'Platform Quotas', + title: 'Platform Quotas', + subtitle: 'Configure daily / weekly / monthly USD usage limits for each upstream platform for user {email}', + columns: { + platform: 'Platform', + daily: 'Daily (USD)', + weekly: 'Weekly (USD)', + monthly: 'Monthly (USD, 30-day rolling)', + usage: 'Current Usage', + }, + placeholder: 'unlimited', + save: 'Save', + saving: 'Saving...', + cancel: 'Cancel', + clearAll: 'Clear All (remove all limits)', + clearAllConfirm: 'Clear daily / weekly / monthly limits for ALL platforms? All platforms will become "unlimited" with no local undo — you must manually re-enter values before saving.', + reset: { + button: 'Reset window', + confirm: 'Reset the {window} usage for {platform} for this user? This is effective immediately.', + success: 'Reset {platform} {window} usage', + failed: 'Reset failed', + }, + updateSuccess: 'Platform quotas updated', + updateFailed: 'Save failed', + loadFailed: 'Load failed', + hint: 'Empty = no limit for that window.', + windowDaily: 'daily', + windowWeekly: 'weekly', + windowMonthly: 'monthly', + cellNotConfigured: 'Not configured', + cellColumnTooltip: 'Only platforms with a limit are shown', + subscriptionWarning: 'This user has an active subscription. Platform quotas only apply to balance (standard) mode requests; subscription mode requests are not subject to these limits.', + invalidNumber: 'The following fields contain invalid numbers. Please fix them before saving: {fields}', } }, @@ -2136,6 +2181,12 @@ export default { finalPricePreview: 'Final per-image price preview', notConfigured: 'Not configured' }, + modelsList: { + title: 'Custom /v1/models Model List', + hint: 'Only changes the /v1/models response. Whitelist model calls and account routing are unchanged.', + loading: 'Loading model list...', + empty: 'No displayable models' + }, claudeCode: { title: 'Claude Code Client Restriction', tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.', @@ -2550,14 +2601,37 @@ export default { modelFilterIncludeSummary: 'Applies to {count} models', modelFilterExcludeSummary: 'Excludes {count} models', emptyLogs: 'No audit records', + preBlockSyncStatus: 'Pre-Block Sync Status', + preBlockSyncHint: 'Live counters for the synchronous moderation path, excluding async record tasks.', + preBlockActive: 'Sync Processing', + preBlockActiveHint: 'Currently checking', + preBlockChecked: 'Checked', + preBlockCheckedHint: 'Entered pre-block path', + preBlockAllowed: 'Allowed', + preBlockAllowedHint: 'No block triggered', + preBlockBlocked: 'Blocked', + preBlockBlockedHint: 'Rejected after hit', + preBlockErrors: 'Audit Errors', + preBlockErrorsHint: 'Failed or no usable key', + preBlockAvgLatency: 'Avg Latency', + preBlockAvgLatencyHint: 'Synchronous path average', + preBlockAPIKeyLoad: 'Audit Key Load', + preBlockAPIKeyLoadHint: 'Synchronous pre-block checks round-robin usable audit keys directly.', + preBlockAPIKeyLoadSummary: 'Sync active {active} / usable keys {available}, {total} total, worker: {workerActive} / {workerTotal}', + preBlockAPIKeyTotals: 'Total {total}, success {success}, errors {errors}', + preBlockAPIKeyLoadEmpty: 'No audit key load data yet', + preBlockKeyActiveShort: 'Active', + preBlockKeyTotalShort: 'Total', + preBlockKeyAvgShort: 'Avg', + preBlockKeyLastShort: 'Last', workerStatus: 'Worker Runtime', - workerStatusHint: 'Queue and worker pool status for asynchronous observation tasks.', + workerStatusHint: 'Queue and worker pool status for async audit tasks and pre-block record tasks, excluding synchronous pre-block checks.', workerPool: 'Worker Pool', workerPoolMeta: '{active} processing, {idle} idle and ready, {total} total', queueUsage: 'Queue Usage', activeWorkers: 'Processing', idleWorkers: 'Idle Ready', - workerActive: 'Processing an asynchronous audit task', + workerActive: 'Processing an async audit or record task', workerIdle: 'Started, idle and ready', workerDisabled: 'Risk control or content audit is disabled', processed: 'Processed', @@ -2566,11 +2640,17 @@ export default { lastCleanup: 'Last cleanup: {time}', cleanupStats: 'Last cleanup deleted {hit} hits and {nonHit} non-hits', riskSwitchOff: 'System switch off', + riskThresholds: 'Risk Thresholds', + riskThresholdsHint: 'Adjust hit thresholds by OpenAI Moderations category. Scores greater than or equal to the threshold count as hits.', + riskThresholdDefault: 'Default {value}', + riskThresholdReset: 'Restore defaults', + riskThresholdPercent: 'Threshold percentage', tabs: { basic: 'Basic', scope: 'Scope', runtime: 'Runtime', response: 'Hit Notice', + riskThresholds: 'Risk Thresholds', keywords: 'Keyword Block', retention: 'Retention', }, @@ -3027,6 +3107,7 @@ export default { usageWindows: 'Usage Windows', proxy: 'Proxy', lastUsed: 'Last Used', + createdAt: 'Created', expiresAt: 'Expires At', actions: 'Actions' }, @@ -3371,6 +3452,9 @@ export default { poolModeRetryCount: 'Same-Account Retries', poolModeRetryCountHint: 'Only applies in pool mode. Use 0 to disable in-place retry. Default {default}, maximum {max}.', + poolModeRetryStatusCodes: 'Retry Status Codes', + poolModeRetryStatusCodesHint: + 'Comma-separated HTTP status codes (100-599) that trigger same-account retry in pool mode. Leave blank to use defaults ({default}).', customErrorCodes: 'Custom Error Codes', customErrorCodesHint: 'Only stop scheduling for selected error codes', customErrorCodesWarning: @@ -5494,7 +5578,17 @@ export default { defaultSubscriptionsDuplicate: 'Duplicate subscription group: {groupId}. Each group can only appear once.', subscriptionGroup: 'Subscription Group', - subscriptionValidityDays: 'Validity (days)' + subscriptionValidityDays: 'Validity (days)', + defaultPlatformQuotas: 'Default Platform Quotas (on signup)', + defaultPlatformQuotasHint: 'Automatically assigned to new users on signup; existing users are not affected. Leave blank = unlimited.', + platformQuotaNotice: 'Monthly quota uses a 30-day rolling window, not a calendar month.', + }, + platformQuota: { + platform: 'Platform', + daily: 'Daily (USD)', + weekly: 'Weekly (USD)', + monthly: 'Monthly (USD, 30d rolling)', + placeholder: 'Unlimited', }, claudeCode: { title: 'Claude Code Settings', @@ -6215,7 +6309,9 @@ export default { grantOnFirstBindHint: 'Grant default entitlements when an existing user first binds this source.', defaultSubscriptionsLabel: 'Default subscriptions', defaultSubscriptionsHint: 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.', - noSourceSubscriptions: 'No source-specific default subscriptions configured.' + noSourceSubscriptions: 'No source-specific default subscriptions configured.', + platformQuotasOverride: 'Platform Quota Overrides', + platformQuotasOverrideHint: 'Blank fields inherit the system default. Set to 0 to fully block that window for this auth source.', }, paymentVisibleMethods: { methodLabel: '{title} visible method', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index b23caa8a..7f0c7fc6 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -634,6 +634,15 @@ export default { platformBreakdownEmpty: '暂无平台用量', platformCount: '{count} 个平台', platformOther: '其他', + platformQuota: { + title: '配额用量', + daily: '日', + weekly: '周', + monthly: '月(近30天)', + resetsAt: '{time} 重置', + noLimit: '不限制', + disabled: '已禁用', + }, tokenUsageTrend: 'Token 使用趋势', noDataAvailable: '暂无数据', model: '模型', @@ -1815,6 +1824,7 @@ export default { groups: '分组', subscriptions: '订阅分组', balance: '余额', + balancePlatformQuota: '余额(平台配额)', usage: '用量', usageAnthropic: '用量 (Claude)', usageOpenAI: '用量 (OpenAI)', @@ -2039,6 +2049,41 @@ export default { failedToReorder: '更新排序失败', keyExists: '属性键已存在', dragToReorder: '拖拽排序' + }, + platformQuota: { + menuItem: '平台限额', + title: '平台限额', + subtitle: '为用户 {email} 配置各上游平台的日 / 周 / 月用量上限', + columns: { + platform: '平台', + daily: '日 (USD)', + weekly: '周 (USD)', + monthly: '月 (USD, 30天滚动)', + usage: '当前用量', + }, + placeholder: '不限制', + save: '保存', + saving: '保存中...', + cancel: '取消', + clearAll: '全部清空(取消所有限额)', + clearAllConfirm: '确认清空全部平台的日 / 周 / 月限额?所有平台将变为"无限额",本地无法撤销,需要在保存前手动重填。', + reset: { + button: '重置该窗口', + confirm: '确认重置该用户 {platform} 平台的 {window} 用量?此操作立即生效。', + success: '已重置 {platform} {window} 用量', + failed: '重置失败', + }, + updateSuccess: '平台限额已更新', + updateFailed: '保存失败', + loadFailed: '加载失败', + hint: '留空 = 不限制该窗口。', + windowDaily: '日', + windowWeekly: '周', + windowMonthly: '月', + cellNotConfigured: '未配置', + cellColumnTooltip: '仅展示已设限额的平台', + subscriptionWarning: '此用户有活跃订阅,平台限额仅在余额(标准)模式下生效,订阅模式请求不受此限额约束。', + invalidNumber: '以下字段填写不是合法数字,请修正后再保存:{fields}', } }, @@ -2219,6 +2264,12 @@ export default { finalPricePreview: '最终单张价格预览', notConfigured: '未配置' }, + modelsList: { + title: '自定义 /v1/models 模型列表', + hint: '仅影响 /v1/models 展示结果,不影响白名单模型调用和账号调度。', + loading: '正在加载模型列表...', + empty: '暂无可展示模型' + }, claudeCode: { title: 'Claude Code 客户端限制', tooltip: @@ -2627,14 +2678,37 @@ export default { modelFilterIncludeSummary: '仅 {count} 个模型生效', modelFilterExcludeSummary: '排除 {count} 个模型', emptyLogs: '暂无审核记录', + preBlockSyncStatus: '前置拦截同步状态', + preBlockSyncHint: '同步审核链路的实时计数,不包含异步写记录任务。', + preBlockActive: '同步处理中', + preBlockActiveHint: '当前正在审核', + preBlockChecked: '已检查', + preBlockCheckedHint: '进入前置拦截链路', + preBlockAllowed: '已放行', + preBlockAllowedHint: '未触发拦截', + preBlockBlocked: '已拦截', + preBlockBlockedHint: '命中后拒绝请求', + preBlockErrors: '审核异常', + preBlockErrorsHint: '失败或无可用 Key', + preBlockAvgLatency: '平均耗时', + preBlockAvgLatencyHint: '同步链路平均值', + preBlockAPIKeyLoad: '审核 Key 负载', + preBlockAPIKeyLoadHint: '同步前置拦截直接轮询可用审核 Key。', + preBlockAPIKeyLoadSummary: '同步并发 {active} / 可用 Key {available},累计 {total} 次,worker:{workerActive} / {workerTotal}', + preBlockAPIKeyTotals: '累计 {total},成功 {success},异常 {errors}', + preBlockAPIKeyLoadEmpty: '暂无审核 Key 负载数据', + preBlockKeyActiveShort: '并发', + preBlockKeyTotalShort: '累计', + preBlockKeyAvgShort: '平均', + preBlockKeyLastShort: '最近', workerStatus: 'Worker 运行状态', - workerStatusHint: '异步观察任务的队列和 worker 池状态。', + workerStatusHint: '异步审计任务和前置拦截记录任务的队列与 Worker 池状态,不包含同步前置拦截审核请求。', workerPool: 'Worker 池', workerPoolMeta: '{active} 个处理中,{idle} 个空闲可用,共 {total} 个', queueUsage: '队列占用', activeWorkers: '处理中', idleWorkers: '空闲可用', - workerActive: '正在处理异步审计任务', + workerActive: '正在处理异步审计或记录任务', workerIdle: '已启动,当前空闲可用', workerDisabled: '风控或内容审计未启用', processed: '已处理', @@ -2643,11 +2717,17 @@ export default { lastCleanup: '上次清理:{time}', cleanupStats: '上次清理删除命中 {hit} 条,未命中 {nonHit} 条', riskSwitchOff: '系统开关关闭', + riskThresholds: '风险阈值', + riskThresholdsHint: '按 OpenAI Moderations 分类调整命中阈值,分数达到或超过阈值即视为命中。', + riskThresholdDefault: '默认 {value}', + riskThresholdReset: '恢复默认阈值', + riskThresholdPercent: '阈值百分比', tabs: { basic: '基础', scope: '审计范围', runtime: '运行队列', response: '命中通知', + riskThresholds: '风险阈值', keywords: '关键词拦截', retention: '日志保留', }, @@ -3063,6 +3143,7 @@ export default { usageWindows: '用量窗口', proxy: '代理', lastUsed: '最近使用', + createdAt: '创建时间', expiresAt: '过期时间', actions: '操作' }, @@ -3514,6 +3595,8 @@ export default { '启用后,上游 429/403/401 错误将自动重试而不标记账号限流或错误,适用于上游指向另一个 sub2api 实例的场景。', poolModeRetryCount: '同账号重试次数', poolModeRetryCountHint: '仅在池模式下生效。0 表示不原地重试;默认 {default},最大 {max}。', + poolModeRetryStatusCodes: '同账号重试状态码', + poolModeRetryStatusCodesHint: '仅在池模式下生效。以英文逗号分隔的 HTTP 状态码(100-599),命中时触发同账号重试。留空使用默认值({default})。', customErrorCodes: '自定义错误码', customErrorCodesHint: '仅对选中的错误码停止调度', customErrorCodesWarning: '仅选中的错误码会停止调度,其他错误将返回 500。', @@ -5653,7 +5736,17 @@ export default { defaultSubscriptionsEmpty: '未配置默认订阅。新用户不会自动获得订阅套餐。', defaultSubscriptionsDuplicate: '默认订阅存在重复分组:{groupId}。每个分组只能出现一次。', subscriptionGroup: '订阅分组', - subscriptionValidityDays: '有效期(天)' + subscriptionValidityDays: '有效期(天)', + defaultPlatformQuotas: '默认平台限额(注册时分配)', + defaultPlatformQuotasHint: '新用户注册时自动写入平台限额记录;已有用户不受影响。留空 = 该平台该窗口不限制。', + platformQuotaNotice: '月限额为 30 天滚动窗口,非自然月', + }, + platformQuota: { + platform: '平台', + daily: '日限额 (USD)', + weekly: '周限额 (USD)', + monthly: '月限额 (USD, 30天滚动)', + placeholder: '不限', }, claudeCode: { title: 'Claude Code 设置', @@ -6373,7 +6466,9 @@ export default { grantOnFirstBindHint: '已有账号首次绑定该来源时发放默认权益。', defaultSubscriptionsLabel: '默认订阅', defaultSubscriptionsHint: '仅对当前认证来源生效,未配置时不追加来源专属订阅。', - noSourceSubscriptions: '当前来源未配置专属默认订阅。' + noSourceSubscriptions: '当前来源未配置专属默认订阅。', + platformQuotasOverride: '平台限额覆盖', + platformQuotasOverrideHint: '留空的字段继承「系统默认平台限额」;填 0 表示禁止该窗口使用。', }, paymentVisibleMethods: { methodLabel: '{title} 可见方式', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 68162e53..535b151e 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -548,11 +548,17 @@ export interface AdminGroup extends Group { // OpenAI Messages 调度配置(仅 openai 平台使用) default_mapped_model?: string messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + models_list_config?: ModelsListConfig // 分组排序 sort_order: number } +export interface ModelsListConfig { + enabled: boolean + models: string[] +} + export interface ApiKey { id: number user_id: number @@ -632,6 +638,13 @@ export interface CreateGroupRequest { fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] + models_list_config?: ModelsListConfig + allow_messages_dispatch?: boolean + default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + model_routing?: Record | null + model_routing_enabled?: boolean + rpm_limit?: number require_oauth_only?: boolean require_privacy_set?: boolean // 从指定分组复制账号 @@ -660,6 +673,13 @@ export interface UpdateGroupRequest { fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] + models_list_config?: ModelsListConfig + allow_messages_dispatch?: boolean + default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig + model_routing?: Record | null + model_routing_enabled?: boolean + rpm_limit?: number require_oauth_only?: boolean require_privacy_set?: boolean copy_accounts_from_group_ids?: number[] @@ -2042,3 +2062,11 @@ export interface UpdateScheduledTestPlanRequest { // Payment types export type { SubscriptionPlan, PaymentOrder, CheckoutInfoResponse } from './payment' + +export type { + PlatformQuotaItem, + PlatformQuotaUpdateItem, + PlatformQuotaPlatform, + PlatformQuotaWindow, + PlatformQuotasResponse, +} from '@/api/admin/users' diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 5f2d27e8..92bb59b0 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -327,6 +327,9 @@ + @@ -648,6 +648,109 @@ +
+
+
+ +

+ {{ t("admin.groups.modelsList.hint") }} +

+
+ +
+
+
+ + 已选 {{ createModelsListSelectedCount }} / + {{ createModelsListState.items.length }} + +
+ + +
+
+
+

+ {{ t("admin.groups.modelsList.loading") }} +

+

+ {{ t("admin.groups.modelsList.empty") }} +

+
+ + + {{ item.id }} + + + +
+
+
+
+
+
+
+ +

+ {{ t("admin.groups.modelsList.hint") }} +

+
+ +
+
+
+ + 已选 {{ editModelsListSelectedCount }} / + {{ editModelsListState.items.length }} + +
+ + +
+
+
+

+ {{ t("admin.groups.modelsList.loading") }} +

+

+ {{ t("admin.groups.modelsList.empty") }} +

+
+ + + {{ item.id }} + + + +
+
+
+
+
(null); const sortableGroups = ref([]); const createMessagesDispatchDefaults = createDefaultMessagesDispatchFormState(); const editMessagesDispatchDefaults = createDefaultMessagesDispatchFormState(); +const createModelsListState = reactive(createInitialModelsListState()); +const editModelsListState = reactive(createInitialModelsListState()); +const createModelsListLoading = ref(false); +const editModelsListLoading = ref(false); +const modelsListCandidatesTracker = createModelsListCandidatesTracker(); +const createModelsListSelectedCount = computed( + () => createModelsListState.items.filter((item) => item.selected).length, +); +const editModelsListSelectedCount = computed( + () => editModelsListState.items.filter((item) => item.selected).length, +); const createForm = reactive({ name: "", @@ -3343,6 +3569,52 @@ const removeEditRoutingRule = (rule: ModelRoutingRule) => { editModelRoutingRules.value.splice(index, 1); }; +const resetModelsListState = ( + state: typeof createModelsListState, + config?: Parameters[0], +) => { + const fresh = createInitialModelsListState(config); + state.enabled = fresh.enabled; + state.savedModels = fresh.savedModels; + state.items = fresh.items; +}; + +const loadModelsListCandidates = async ( + mode: "create" | "edit", + groupID: number, + platform: GroupPlatform, +) => { + const request = { mode, groupID, platform }; + const requestID = modelsListCandidatesTracker.next(request); + const state = mode === "create" ? createModelsListState : editModelsListState; + const loadingRef = mode === "create" ? createModelsListLoading : editModelsListLoading; + loadingRef.value = true; + try { + const models = await adminAPI.groups.getModelsListCandidates(groupID, platform); + if (!modelsListCandidatesTracker.isCurrent(requestID, request)) { + return; + } + setModelsListCandidates(state, models); + } catch (error) { + if (!modelsListCandidatesTracker.isCurrent(requestID, request)) { + return; + } + console.error("Error loading group models list candidates:", error); + } finally { + if (modelsListCandidatesTracker.isCurrent(requestID, request)) { + loadingRef.value = false; + } + } +}; + +const moveCreateModelsListItem = (fromIndex: number, toIndex: number) => { + moveModelsListItem(createModelsListState, fromIndex, toIndex); +}; + +const moveEditModelsListItem = (fromIndex: number, toIndex: number) => { + moveModelsListItem(editModelsListState, fromIndex, toIndex); +}; + // 将 UI 格式的路由规则转换为 API 格式 const convertRoutingRulesToApiFormat = ( rules: ModelRoutingRule[], @@ -3632,6 +3904,11 @@ const handleSort = (key: string, order: 'asc' | 'desc') => { loadGroups(); }; +const openCreateModal = () => { + showCreateModal.value = true; + loadModelsListCandidates("create", 0, createForm.platform); +}; + const closeCreateModal = () => { showCreateModal.value = false; createModelRoutingRules.value.forEach((rule) => { @@ -3662,6 +3939,8 @@ const closeCreateModal = () => { createForm.supported_model_scopes = ["claude", "gemini_text", "gemini_image"]; createForm.mcp_xml_inject = true; createForm.copy_accounts_from_group_ids = []; + createForm.rpm_limit = 0; + resetModelsListState(createModelsListState); createModelRoutingRules.value = []; }; @@ -3716,6 +3995,7 @@ const handleCreateGroup = async () => { model_routing: convertRoutingRulesToApiFormat( createModelRoutingRules.value, ), + models_list_config: buildModelsListConfig(createModelsListState), supported_model_scopes: normalizeSupportedModelScopesForPlatform( createForm.platform, createForm.supported_model_scopes, @@ -3802,10 +4082,12 @@ const handleEdit = async (group: AdminGroup) => { editForm.mcp_xml_inject = group.mcp_xml_inject ?? true; editForm.copy_accounts_from_group_ids = []; // 复制账号字段每次编辑时重置为空 editForm.rpm_limit = group.rpm_limit ?? 0; + resetModelsListState(editModelsListState, group.models_list_config); // 加载模型路由规则(异步加载账号名称) editModelRoutingRules.value = await convertApiFormatToRoutingRules( group.model_routing, ); + loadModelsListCandidates("edit", group.id, group.platform); showEditModal.value = true; }; @@ -3819,6 +4101,7 @@ const closeEditModal = () => { editModelRoutingRules.value = []; editForm.copy_accounts_from_group_ids = []; resetMessagesDispatchFormState(editForm); + resetModelsListState(editModelsListState); }; const handleUpdateGroup = async () => { @@ -3851,6 +4134,7 @@ const handleUpdateGroup = async () => { model_routing: convertRoutingRulesToApiFormat( editModelRoutingRules.value, ), + models_list_config: buildModelsListConfig(editModelsListState), supported_model_scopes: normalizeSupportedModelScopesForPlatform( editForm.platform, editForm.supported_model_scopes, @@ -3968,6 +4252,8 @@ watch( createForm.require_oauth_only = false; createForm.require_privacy_set = false; } + resetModelsListState(createModelsListState); + loadModelsListCandidates("create", 0, newVal); }, ); @@ -3984,6 +4270,10 @@ watch( editForm.require_oauth_only = false; editForm.require_privacy_set = false; } + if (editingGroup.value) { + resetModelsListState(editModelsListState, editForm.platform === editingGroup.value.platform ? editingGroup.value.models_list_config : undefined); + loadModelsListCandidates("edit", editingGroup.value.id, newVal); + } }, ); @@ -4057,6 +4347,7 @@ const saveSortOrder = async () => { onMounted(() => { loadGroups(); + loadModelsListCandidates("create", 0, createForm.platform); document.addEventListener("click", handleClickOutside); }); diff --git a/frontend/src/views/admin/RiskControlView.vue b/frontend/src/views/admin/RiskControlView.vue index 4d56b492..b6d62767 100644 --- a/frontend/src/views/admin/RiskControlView.vue +++ b/frontend/src/views/admin/RiskControlView.vue @@ -53,7 +53,105 @@
-
+
+
+
+
+

{{ t('admin.riskControl.preBlockSyncStatus') }}

+

{{ t('admin.riskControl.preBlockSyncHint') }}

+
+ + {{ modeLabel(status?.mode ?? configForm.mode) }} + +
+ +
+
+
+

{{ item.label }}

+

{{ item.value }}

+

{{ item.meta }}

+
+
+
+
+ +
+
+
+

{{ t('admin.riskControl.preBlockAPIKeyLoad') }}

+

+ {{ t('admin.riskControl.preBlockAPIKeyLoadHint') }} +

+
+ + {{ preBlockAPIKeyLoadSummaryText }} + +
+ +
+
+
+
+
+
+ #{{ item.index + 1 }} + {{ item.masked || '-' }} + +
+

+ {{ t('admin.riskControl.preBlockAPIKeyTotals', { total: formatNumber(item.total), success: formatNumber(item.success), errors: formatNumber(item.errors) }) }} +

+
+
+
+

{{ t('admin.riskControl.preBlockKeyActiveShort') }}

+

{{ formatNumber(item.active) }}

+
+
+

{{ t('admin.riskControl.preBlockKeyTotalShort') }}

+

{{ formatNumber(item.total) }}

+
+
+

{{ t('admin.riskControl.preBlockKeyAvgShort') }}

+

{{ formatNumber(item.avg_latency_ms) }} ms

+
+
+

{{ t('admin.riskControl.preBlockKeyLastShort') }}

+

{{ formatNumber(item.last_latency_ms) }} ms

+
+
+
+
+
+
+
+
+

+ {{ t('admin.riskControl.preBlockAPIKeyLoadEmpty') }} +

+
+
+
+ +

{{ t('admin.riskControl.workerStatus') }}

@@ -794,6 +892,63 @@
+
+
+
+

{{ t('admin.riskControl.riskThresholds') }}

+

{{ t('admin.riskControl.riskThresholdsHint') }}

+
+ +
+ +
+
+
+
+ +

+ {{ t('admin.riskControl.riskThresholdDefault', { value: formatThresholdPercent(row.defaultValue) }) }} +

+
+ + {{ formatThresholdPercent(row.value) }} + +
+
+ +
+ + % +
+
+
+
+
+
= { + harassment: 98, + 'harassment/threatening': 90, + hate: 65, + 'hate/threatening': 65, + illicit: 95, + 'illicit/violent': 95, + 'self-harm': 65, + 'self-harm/intent': 85, + 'self-harm/instructions': 65, + sexual: 65, + 'sexual/minors': 65, + violence: 95, + 'violence/graphic': 95, +} +const riskThresholdCategories = Object.keys(riskThresholdDefaults) const { t } = useI18n() const appStore = useAppStore() @@ -1054,6 +1231,7 @@ const configForm = reactive({ hit_retention_days: 180, non_hit_retention_days: 3, pre_hash_check_enabled: false, + thresholds: { ...riskThresholdDefaults } as Record, blocked_keywords_text: '', keyword_blocking_mode: 'keyword_and_api' as KeywordBlockingMode, model_filter_type: 'all' as ContentModerationModelFilterType, @@ -1081,6 +1259,7 @@ const settingsTabs = computed>(() => [ { id: 'scope', label: t('admin.riskControl.tabs.scope') }, { id: 'runtime', label: t('admin.riskControl.tabs.runtime') }, { id: 'response', label: t('admin.riskControl.tabs.response') }, + { id: 'riskThresholds', label: t('admin.riskControl.tabs.riskThresholds') }, { id: 'keywords', label: t('admin.riskControl.tabs.keywords') }, { id: 'retention', label: t('admin.riskControl.tabs.retention') }, ]) @@ -1373,6 +1552,14 @@ const moderationScoreRows = computed(() => { .sort((a, b) => b.score - a.score) }) +const riskThresholdRows = computed(() => ( + riskThresholdCategories.map((category) => ({ + category, + value: configForm.thresholds[category] ?? riskThresholdDefaults[category], + defaultValue: riskThresholdDefaults[category], + })) +)) + const inputDetailText = computed(() => { if (!inputDetailRow.value) return '-' return inputDetailRow.value.input_excerpt || inputDetailRow.value.error || '-' @@ -1384,6 +1571,81 @@ const queueUsageStyle = computed(() => ({ width: queueUsagePercent.value, })) +const runtimeMode = computed(() => status.value?.mode ?? configForm.mode) + +const showPreBlockRuntimeCard = computed(() => runtimeMode.value === 'pre_block') + +const showWorkerRuntimeCard = computed(() => runtimeMode.value === 'observe') + +const preBlockMetricItems = computed(() => [ + { + key: 'active', + label: t('admin.riskControl.preBlockActive'), + value: formatNumber(status.value?.pre_block_active ?? 0), + meta: t('admin.riskControl.preBlockActiveHint'), + class: 'bg-sky-50 dark:bg-sky-900/10', + valueClass: 'text-sky-700 dark:text-sky-300', + }, + { + key: 'checked', + label: t('admin.riskControl.preBlockChecked'), + value: formatNumber(status.value?.pre_block_checked ?? 0), + meta: t('admin.riskControl.preBlockCheckedHint'), + class: 'bg-gray-50 dark:bg-dark-700/50', + valueClass: 'text-gray-900 dark:text-white', + }, + { + key: 'allowed', + label: t('admin.riskControl.preBlockAllowed'), + value: formatNumber(status.value?.pre_block_allowed ?? 0), + meta: t('admin.riskControl.preBlockAllowedHint'), + class: 'bg-emerald-50 dark:bg-emerald-900/10', + valueClass: 'text-emerald-700 dark:text-emerald-300', + }, + { + key: 'blocked', + label: t('admin.riskControl.preBlockBlocked'), + value: formatNumber(status.value?.pre_block_blocked ?? 0), + meta: t('admin.riskControl.preBlockBlockedHint'), + class: 'bg-rose-50 dark:bg-rose-900/10', + valueClass: 'text-rose-700 dark:text-rose-300', + }, + { + key: 'errors', + label: t('admin.riskControl.preBlockErrors'), + value: formatNumber(status.value?.pre_block_errors ?? 0), + meta: t('admin.riskControl.preBlockErrorsHint'), + class: 'bg-amber-50 dark:bg-amber-900/10', + valueClass: 'text-amber-700 dark:text-amber-300', + }, + { + key: 'latency', + label: t('admin.riskControl.preBlockAvgLatency'), + value: `${formatNumber(status.value?.pre_block_avg_latency_ms ?? 0)} ms`, + meta: t('admin.riskControl.preBlockAvgLatencyHint'), + class: 'bg-violet-50 dark:bg-violet-900/10', + valueClass: 'text-violet-700 dark:text-violet-300', + }, +]) + +const preBlockAPIKeyLoads = computed(() => ( + [...(status.value?.pre_block_api_key_loads ?? [])].sort((a, b) => a.index - b.index) +)) + +const preBlockAPIKeyMaxTotal = computed(() => Math.max(1, ...preBlockAPIKeyLoads.value.map((item) => item.total || 0))) + +const preBlockAPIKeyLoadSummaryText = computed(() => t('admin.riskControl.preBlockAPIKeyLoadSummary', { + active: formatNumber(status.value?.pre_block_api_key_active ?? 0), + available: formatNumber(status.value?.pre_block_api_key_available_count ?? 0), + total: formatNumber(status.value?.pre_block_api_key_total_calls ?? 0), + workerActive: formatNumber(status.value?.active_workers ?? 0), + workerTotal: formatNumber(status.value?.worker_count ?? configForm.worker_count), +})) + +function preBlockAPIKeyLoadWidth(total: number): string { + return `${Math.min(100, Math.max(0, (total / preBlockAPIKeyMaxTotal.value) * 100)).toFixed(1)}%` +} + const workerSlots = computed(() => { const total = Math.max(0, status.value?.worker_count ?? configForm.worker_count) const active = Math.max(0, status.value?.active_workers ?? 0) @@ -1445,6 +1707,7 @@ function applyConfig(config: ContentModerationConfig) { configForm.hit_retention_days = config.hit_retention_days || 180 configForm.non_hit_retention_days = Math.min(Math.max(config.non_hit_retention_days || 3, 1), 3) configForm.pre_hash_check_enabled = config.pre_hash_check_enabled ?? false + configForm.thresholds = riskThresholdsFromConfig(config.thresholds) configForm.blocked_keywords_text = Array.isArray(config.blocked_keywords) ? config.blocked_keywords.join('\n') : '' configForm.keyword_blocking_mode = normalizeKeywordBlockingMode(config.keyword_blocking_mode) const modelFilter = normalizeModelFilter(config.model_filter) @@ -1524,6 +1787,7 @@ async function saveConfig() { hit_retention_days: Number(configForm.hit_retention_days) || 180, non_hit_retention_days: Math.min(Math.max(Number(configForm.non_hit_retention_days) || 3, 1), 3), pre_hash_check_enabled: configForm.pre_hash_check_enabled, + thresholds: buildRiskThresholdPayload(), blocked_keywords: blockedKeywordList.value, keyword_blocking_mode: configForm.keyword_blocking_mode, model_filter: modelFilterPayload, @@ -1988,6 +2252,41 @@ function buildModelFilterPayload(): ContentModerationModelFilter { } } +function riskThresholdsFromConfig(thresholds: Record | null | undefined): Record { + const out: Record = { ...riskThresholdDefaults } + for (const category of riskThresholdCategories) { + const value = thresholds?.[category] + if (Number.isFinite(value)) { + out[category] = clampPercent(Number(value) * 100) + } + } + return out +} + +function buildRiskThresholdPayload(): Record { + const payload: Record = {} + for (const category of riskThresholdCategories) { + payload[category] = Number((clampPercent(configForm.thresholds[category]) / 100).toFixed(4)) + } + return payload +} + +function resetRiskThresholds() { + configForm.thresholds = { ...riskThresholdDefaults } +} + +function clampPercent(value: unknown): number { + const numeric = Number(value) + if (!Number.isFinite(numeric)) { + return 0 + } + return Math.min(100, Math.max(0, numeric)) +} + +function formatThresholdPercent(value: number): string { + return `${clampPercent(value).toFixed(1)}%` +} + function parseBlockedKeywords(value: string): string[] { const seen = new Set() const out: string[] = [] diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 6e542e52..6f682a3c 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3262,6 +3262,71 @@
+ + +
+
+ +

+ {{ t("admin.settings.defaults.defaultPlatformQuotasHint") }} +

+

+ {{ t("admin.settings.defaults.platformQuotaNotice") }} +

+
+
+ + + + + + + + + + + + + + + + + +
{{ t("admin.settings.platformQuota.platform") }}{{ t("admin.settings.platformQuota.daily") }}{{ t("admin.settings.platformQuota.weekly") }}{{ t("admin.settings.platformQuota.monthly") }}
+ {{ p }} + + + + + + +
+
+
+
@@ -3535,6 +3600,68 @@ + + +
+
+ +

+ {{ t("admin.settings.authSourceDefaults.platformQuotasOverrideHint") }} +

+
+
+ + + + + + + + + + + + + + + + + +
{{ t("admin.settings.platformQuota.platform") }}{{ t("admin.settings.platformQuota.daily") }}{{ t("admin.settings.platformQuota.weekly") }}{{ t("admin.settings.platformQuota.monthly") }}
+ {{ p }} + + + + + + +
+
+
+ @@ -6530,6 +6657,8 @@ import { adminAPI } from "@/api"; import { appendAuthSourceDefaultsToUpdateRequest, buildAuthSourceDefaultsState, + normalizePlatformQuotasMap, + sanitizePlatformQuotasMap, defaultWeChatConnectScopesForMode, deriveWeChatConnectStoredMode, normalizeDefaultSubscriptionSettings, @@ -6541,6 +6670,7 @@ import type { SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, + DefaultPlatformQuotasMap, OpenAIFastPolicyRule, WeChatConnectMode, WebSearchEmulationConfig, @@ -6835,6 +6965,8 @@ type SettingsForm = Omit< google_oauth_client_secret: string; force_email_on_third_party_signup: boolean; openai_advanced_scheduler_enabled: boolean; + // 系统全局平台限额 map;form 内始终归一化为全 4 平台对象(模板非空绑定依赖此不变量) + default_platform_quotas: DefaultPlatformQuotasMap; }; const form = reactive({ @@ -6851,6 +6983,7 @@ const form = reactive({ login_agreement_updated_at: "2026-03-31", login_agreement_documents: defaultLoginAgreementDocuments(), default_balance: 0, + default_platform_quotas: normalizePlatformQuotasMap() as DefaultPlatformQuotasMap, affiliate_rebate_rate: 20, affiliate_rebate_freeze_hours: 0, affiliate_rebate_duration_days: 0, @@ -7660,6 +7793,7 @@ async function loadSettings() { })) : defaultLoginAgreementDocuments(); Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings)); + form.default_platform_quotas = normalizePlatformQuotasMap(settings.default_platform_quotas); form.backend_mode_enabled = settings.backend_mode_enabled; form.default_subscriptions = normalizeDefaultSubscriptionSettings( settings.default_subscriptions, @@ -8214,6 +8348,7 @@ async function saveSettings() { }; } + payload.default_platform_quotas = sanitizePlatformQuotasMap(form.default_platform_quotas); appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults); const updated = await adminAPI.settings.updateSettings(payload); @@ -8224,6 +8359,7 @@ async function saveSettings() { } } Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated)); + form.default_platform_quotas = normalizePlatformQuotasMap(updated.default_platform_quotas); registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( updated.registration_email_suffix_whitelist, diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index 512bae67..93c70c6b 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -420,6 +420,17 @@ + + @@ -673,6 +684,15 @@ {{ t('admin.users.withdraw') }} + + +