chore: merge upstream Wei-Shaw/sub2api v0.1.132
Conflicts resolved (preserving fork customizations): - config.go: keep NodeTLSProxy + add upstream OpenAIHTTP2 - gateway_service.go: NewGatewayService now takes both rpmTokenBucketSvc (local) and userPlatformQuotaRepo (upstream) - wire_gen.go: wire both new args into the call site - http_upstream.go: drop redundant settings re-assignment; keep proxy URL log redaction - http_upstream_test.go: adopt upstream's explicit-0-disables semantics; keep 600s default constant in nil-cfg fallback test - user_handler_test.go / gateway_record_usage_test.go: pick up new userPlatformQuotaRepo nil parameter Also updated test stubs (windsurf_google_login_test.go, windsurf_tier_access_service_test.go, gateway_models_test.go) for new SetModelRateLimit variadic signature and the extra NewGatewayService arg. Upstream highlights: OpenAI embeddings gateway, user x platform USD quota, content-moderation risk thresholds, OAuth 401 credentials no-overwrite fix, HTTP/2 OpenAI upstream config, pool retry status code configurability, long-context cache pricing multipliers.
This commit is contained in:
commit
f519a02ec9
16
README.md
16
README.md
@ -103,9 +103,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://pateway.ai/?ch=1tsfr51"><img src="assets/partners/logos/pateway.png" alt="pateway" width="150"></a></td>
|
||||
<td>Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
|
||||
Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
|
||||
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
|
||||
<td>Thanks to PatewayAI for sponsoring this project! <a href="https://pateway.ai/?ch=1tsfr51">PatewayAI</a> is a premium API relay built for heavy AI developers, offering the full Claude and Codex series sourced 100% from official providers, with transparent token-level billing. Enterprise plans include high concurrency, dedicated management, contracts, and invoicing. Register now to get $3 in trial credits, top-ups from 60% off, and referral bonuses up to $150.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
@ -120,6 +118,18 @@ Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to recei
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
|
||||
<td>Thanks to Unity2 for sponsoring this project! <a href="https://unity2.ai/register?source=sub2api">Unity2</a> is a high-performance AI model API relay for individuals, teams, and enterprises, handling 30B+ tokens/day with 5000 RPM concurrency. One API Key works across Claude Code, Codex, OpenAI models, IDE plugins, and Agent workflows, with balance billing, bundled subscriptions, enterprise invoicing, and 1-on-1 support. <a href="https://unity2.ai/register?source=sub2api">Register</a> to claim $2 in balance, plus $10 more by joining the official group — up to $12 in free credit.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
|
||||
<td>Thanks to Veilx for sponsoring this project! <a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN is purpose-built for large-scale AI API traffic, deeply optimized for relay services and call chains across OpenAI, Claude, Gemini, and scenarios like chat, image generation, embeddings, and streaming — delivering lower latency and higher stability under heavy concurrency. It also offers China three-network optimized return lines, making it ideal for global AI relay platforms, overseas AI SaaS, and cross-border high-concurrency deployments.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
12
README_CN.md
12
README_CN.md
@ -119,6 +119,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
|
||||
<td>感谢 Unity2 赞助本项目! <a href="https://unity2.ai/register?source=sub2api">Unity2</a> 是面向个人开发者、团队、企业的高性能 AI 模型 API 中转平台,长期服务国内头部企业,日均承载超 300 亿 token 调用,支持 5000 RPM 级高并发。一个 API Key 即可适配 Claude Code、Codex、OpenAI 模型、IDE 插件和 Agent 工作流等场景。具备企业级稳定供应能力,在高并发、持续调用和团队集中采购场景下依然保持低延迟、高可用。同时支持余额计费、组合订阅、首充优惠、企业开票、专属 1v1 对接,适合个人高频使用和企业长期接入。现在注册 Unity2.ai 可领取 $2 余额,加入官方群再送 $10 余额,合计最高可领 $12 免费额度,适合先体验后长期使用。<a href="https://unity2.ai/register?source=sub2api">注册链接</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
|
||||
<td>感谢 Veilx 赞助本项目! <a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN 专为超大规模 API 请求场景打造,针对 AI 中转站业务与 AI API 调用链路进行了深度优化,轻松应对高并发、高频请求与大流量传输,为开发者与企业提供更快、更稳、更低延迟的加速体验。无论是 OpenAI、Claude、Gemini 等 AI 接口中转,还是聊天、绘图、Embedding、流式输出等复杂场景,Veilx 都能显著提升响应速度与连接稳定性,有效降低网络波动带来的超时与失败问题。同时,Veilx 提供中国三网优化回国极速线路,大幅提升中国大陆地区访问海外 AI 服务的速度与稳定性,特别适合全球 AI 中转平台、海外 AI SaaS、跨境业务与高并发 API 系统部署。专为 AI API 而生,让你的 AI 中转服务更快、更稳、更省心。<a href="https://veilx.io/#/hello/SJRBRVDV">购买地址</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
12
README_JA.md
12
README_JA.md
@ -119,6 +119,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://unity2.ai/register?source=sub2api"><img src="assets/partners/logos/unity2.png" alt="unity2" width="150"></a></td>
|
||||
<td>Unity2 のご支援に感謝します!<a href="https://unity2.ai/register?source=sub2api">Unity2</a> は個人開発者、チーム、企業向けの高性能 AI モデル API 中継プラットフォームです。中国の大手企業に長期にわたりサービスを提供しており、1日あたり 300 億以上のトークン呼び出しを処理し、5000 RPM 級の高並列性をサポートします。1つの API キーで Claude Code、Codex、OpenAI モデル、IDE プラグイン、Agent ワークフローなど様々なシナリオに対応できます。エンタープライズグレードの安定供給能力を備え、高並列・継続的な呼び出し・チームの集中購入シーンでも低レイテンシと高可用性を維持します。残高課金、組み合わせサブスクリプション、初回チャージ特典、企業向け請求書発行、専属 1v1 サポートにも対応しており、個人の頻繁な利用にも企業の長期導入にも適しています。今 Unity2.ai に登録すると $2 の残高、公式グループに参加するとさらに $10 の残高がもらえ、合計最大 $12 の無料クレジットを獲得できます — 試用後に長期利用したい方に最適です。<a href="https://unity2.ai/register?source=sub2api">登録リンク</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://veilx.io/#/hello/SJRBRVDV"><img src="assets/partners/logos/veilx.png" alt="veilx" width="150"></a></td>
|
||||
<td>Veilx のご支援に感謝します!<a href="https://veilx.io/#/hello/SJRBRVDV">Veilx</a> CDN は超大規模 API リクエストシナリオ向けに設計されており、AI 中継サービスと AI API 呼び出しチェーンに対して深く最適化されています。高並列・高頻度リクエスト・大容量トラフィックに容易に対応し、開発者と企業により高速で安定した、低レイテンシの加速体験を提供します。OpenAI、Claude、Gemini などの AI インターフェース中継はもちろん、チャット、画像生成、Embedding、ストリーミング出力などの複雑なシナリオでも、Veilx は応答速度と接続安定性を大幅に向上させ、ネットワーク変動によるタイムアウトや失敗を効果的に削減します。さらに、Veilx は中国三大ネットワーク最適化の高速回線を提供しており、中国本土から海外 AI サービスへのアクセス速度と安定性を大幅に向上させます。グローバル AI 中継プラットフォーム、海外 AI SaaS、越境ビジネス、高並列 API システム展開に特に適しています。AI API のために生まれ、あなたの AI 中継サービスをより速く、より安定して、より安心に。<a href="https://veilx.io/#/hello/SJRBRVDV">購入リンク</a>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/unity2.png
Normal file
BIN
assets/partners/logos/unity2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 744 KiB |
BIN
assets/partners/logos/veilx.png
Normal file
BIN
assets/partners/logos/veilx.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 376 KiB |
@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@ -1 +1 @@
|
||||
0.1.130
|
||||
0.1.132
|
||||
|
||||
@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
||||
userRPMCache := repository.NewUserRPMCache(redisClient)
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||
userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client)
|
||||
serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository)
|
||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
|
||||
@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
@ -194,7 +196,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService, serviceUserPlatformQuotaRepository)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||
pricingSvc := service.NewPricingService(cfg, nil)
|
||||
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
|
||||
@ -48,6 +48,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
|
||||
stdsql "database/sql"
|
||||
@ -124,6 +125,8 @@ type Client struct {
|
||||
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||
UserAttributeValue *UserAttributeValueClient
|
||||
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
|
||||
UserPlatformQuota *UserPlatformQuotaClient
|
||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||
UserSubscription *UserSubscriptionClient
|
||||
}
|
||||
@ -170,6 +173,7 @@ func (c *Client) init() {
|
||||
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
|
||||
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
|
||||
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
|
||||
c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config)
|
||||
c.UserSubscription = NewUserSubscriptionClient(c.config)
|
||||
}
|
||||
|
||||
@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
|
||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||
}, nil
|
||||
}
|
||||
@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
|
||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||
}, nil
|
||||
}
|
||||
@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
|
||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
c.UserPlatformQuota, c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
|
||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
c.UserPlatformQuota, c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.UserAttributeDefinition.mutate(ctx, m)
|
||||
case *UserAttributeValueMutation:
|
||||
return c.UserAttributeValue.mutate(ctx, m)
|
||||
case *UserPlatformQuotaMutation:
|
||||
return c.UserPlatformQuota.mutate(ctx, m)
|
||||
case *UserSubscriptionMutation:
|
||||
return c.UserSubscription.mutate(ctx, m)
|
||||
default:
|
||||
@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas queries the platform_quotas edge of a User.
|
||||
func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery {
|
||||
query := (&UserPlatformQuotaClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, id),
|
||||
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
|
||||
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: c.config}).Query()
|
||||
@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV
|
||||
}
|
||||
}
|
||||
|
||||
// UserPlatformQuotaClient is a client for the UserPlatformQuota schema.
|
||||
type UserPlatformQuotaClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config.
|
||||
func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient {
|
||||
return &UserPlatformQuotaClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`.
|
||||
func (c *UserPlatformQuotaClient) Use(hooks ...Hook) {
|
||||
c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`.
|
||||
func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a UserPlatformQuota entity.
|
||||
func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpCreate)
|
||||
return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities.
|
||||
func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk {
|
||||
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*UserPlatformQuotaCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdate)
|
||||
return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m))
|
||||
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id))
|
||||
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpDelete)
|
||||
return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne {
|
||||
builder := c.Delete().Where(userplatformquota.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &UserPlatformQuotaDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery {
|
||||
return &UserPlatformQuotaQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypeUserPlatformQuota},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a UserPlatformQuota entity by its id.
|
||||
func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) {
|
||||
return c.Query().Where(userplatformquota.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// QueryUser queries the user edge of a UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery {
|
||||
query := (&UserClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *UserPlatformQuotaClient) Hooks() []Hook {
|
||||
hooks := c.hooks.UserPlatformQuota
|
||||
return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...)
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *UserPlatformQuotaClient) Interceptors() []Interceptor {
|
||||
inters := c.inters.UserPlatformQuota
|
||||
return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...)
|
||||
}
|
||||
|
||||
func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// UserSubscriptionClient is a client for the UserSubscription schema.
|
||||
type UserSubscriptionClient struct {
|
||||
config
|
||||
@ -6025,7 +6200,8 @@ type (
|
||||
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
|
||||
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
|
||||
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
|
||||
UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
|
||||
@ -6035,7 +6211,8 @@ type (
|
||||
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
|
||||
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
|
||||
UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -45,6 +45,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -139,6 +140,7 @@ func checkColumn(t, c string) error {
|
||||
userallowedgroup.Table: userallowedgroup.ValidColumn,
|
||||
userattributedefinition.Table: userattributedefinition.ValidColumn,
|
||||
userattributevalue.Table: userattributevalue.ValidColumn,
|
||||
userplatformquota.Table: userplatformquota.ValidColumn,
|
||||
usersubscription.Table: usersubscription.ValidColumn,
|
||||
})
|
||||
})
|
||||
|
||||
@ -85,6 +85,8 @@ type Group struct {
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
|
||||
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
|
||||
// 自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度
|
||||
ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config,omitempty"`
|
||||
// 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流
|
||||
RpmLimit int `json:"rpm_limit,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
@ -193,7 +195,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
|
||||
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig, group.FieldModelsListConfig:
|
||||
values[i] = new([]byte)
|
||||
case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
|
||||
values[i] = new(sql.NullBool)
|
||||
@ -440,6 +442,14 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
|
||||
}
|
||||
}
|
||||
case group.FieldModelsListConfig:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field models_list_config", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.ModelsListConfig); err != nil {
|
||||
return fmt.Errorf("unmarshal field models_list_config: %w", err)
|
||||
}
|
||||
}
|
||||
case group.FieldRpmLimit:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
|
||||
@ -641,6 +651,9 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("messages_dispatch_model_config=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("models_list_config=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ModelsListConfig))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rpm_limit=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
|
||||
builder.WriteByte(')')
|
||||
|
||||
@ -82,6 +82,8 @@ const (
|
||||
FieldDefaultMappedModel = "default_mapped_model"
|
||||
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
|
||||
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
|
||||
// FieldModelsListConfig holds the string denoting the models_list_config field in the database.
|
||||
FieldModelsListConfig = "models_list_config"
|
||||
// FieldRpmLimit holds the string denoting the rpm_limit field in the database.
|
||||
FieldRpmLimit = "rpm_limit"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
@ -192,6 +194,7 @@ var Columns = []string{
|
||||
FieldRequirePrivacySet,
|
||||
FieldDefaultMappedModel,
|
||||
FieldMessagesDispatchModelConfig,
|
||||
FieldModelsListConfig,
|
||||
FieldRpmLimit,
|
||||
}
|
||||
|
||||
@ -276,6 +279,8 @@ var (
|
||||
DefaultMappedModelValidator func(string) error
|
||||
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
|
||||
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
|
||||
// DefaultModelsListConfig holds the default value on creation for the "models_list_config" field.
|
||||
DefaultModelsListConfig domain.GroupModelsListConfig
|
||||
// DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
|
||||
DefaultRpmLimit int
|
||||
)
|
||||
|
||||
@ -467,6 +467,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (_c *GroupCreate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupCreate {
|
||||
_c.mutation.SetModelsListConfig(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetModelsListConfig(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
@ -698,6 +712,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultMessagesDispatchModelConfig
|
||||
_c.mutation.SetMessagesDispatchModelConfig(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ModelsListConfig(); !ok {
|
||||
v := group.DefaultModelsListConfig
|
||||
_c.mutation.SetModelsListConfig(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
v := group.DefaultRpmLimit
|
||||
_c.mutation.SetRpmLimit(v)
|
||||
@ -798,6 +816,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
|
||||
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ModelsListConfig(); !ok {
|
||||
return &ValidationError{Name: "models_list_config", err: errors.New(`ent: missing required field "Group.models_list_config"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RpmLimit(); !ok {
|
||||
return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
|
||||
}
|
||||
@ -960,6 +981,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
_node.MessagesDispatchModelConfig = value
|
||||
}
|
||||
if value, ok := _c.mutation.ModelsListConfig(); ok {
|
||||
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
|
||||
_node.ModelsListConfig = value
|
||||
}
|
||||
if value, ok := _c.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
_node.RpmLimit = value
|
||||
@ -1642,6 +1667,18 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (u *GroupUpsert) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsert {
|
||||
u.Set(group.FieldModelsListConfig, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateModelsListConfig() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldModelsListConfig)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
|
||||
u.Set(group.FieldRpmLimit, v)
|
||||
@ -2314,6 +2351,20 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (u *GroupUpsertOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelsListConfig(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateModelsListConfig() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelsListConfig()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
@ -3155,6 +3206,20 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (u *GroupUpsertBulk) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetModelsListConfig(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModelsListConfig sets the "models_list_config" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateModelsListConfig() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateModelsListConfig()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
|
||||
@ -616,6 +616,20 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (_u *GroupUpdate) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdate {
|
||||
_u.mutation.SetModelsListConfig(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetModelsListConfig(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
@ -1112,6 +1126,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelsListConfig(); ok {
|
||||
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
@ -2012,6 +2029,20 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetModelsListConfig sets the "models_list_config" field.
|
||||
func (_u *GroupUpdateOne) SetModelsListConfig(v domain.GroupModelsListConfig) *GroupUpdateOne {
|
||||
_u.mutation.SetModelsListConfig(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableModelsListConfig sets the "models_list_config" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableModelsListConfig(v *domain.GroupModelsListConfig) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetModelsListConfig(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRpmLimit sets the "rpm_limit" field.
|
||||
func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
|
||||
_u.mutation.ResetRpmLimit()
|
||||
@ -2538,6 +2569,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
|
||||
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.ModelsListConfig(); ok {
|
||||
_spec.SetField(group.FieldModelsListConfig, field.TypeJSON, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RpmLimit(); ok {
|
||||
_spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
|
||||
}
|
||||
|
||||
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary
|
||||
// function as UserPlatformQuota mutator.
|
||||
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m)
|
||||
}
|
||||
|
||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
|
||||
// function as UserSubscription mutator.
|
||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)
|
||||
|
||||
@ -42,6 +42,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
|
||||
}
|
||||
|
||||
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
|
||||
}
|
||||
|
||||
// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
|
||||
}
|
||||
|
||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
|
||||
|
||||
@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
|
||||
case *ent.UserAttributeValueQuery:
|
||||
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
|
||||
case *ent.UserPlatformQuotaQuery:
|
||||
return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil
|
||||
case *ent.UserSubscriptionQuery:
|
||||
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
|
||||
default:
|
||||
|
||||
@ -669,6 +669,7 @@ var (
|
||||
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
|
||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "models_list_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "rpm_limit", Type: field.TypeInt, Default: 0},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
@ -1612,6 +1613,53 @@ var (
|
||||
},
|
||||
},
|
||||
}
|
||||
// UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table.
|
||||
UserPlatformQuotasColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "platform", Type: field.TypeString, Size: 32},
|
||||
{Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
}
|
||||
// UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table.
|
||||
UserPlatformQuotasTable = &schema.Table{
|
||||
Name: "user_platform_quotas",
|
||||
Columns: UserPlatformQuotasColumns,
|
||||
PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]},
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "user_platform_quotas_users_platform_quotas",
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "userplatformquota_user_id_platform",
|
||||
Unique: true,
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]},
|
||||
Annotation: &entsql.IndexAnnotation{
|
||||
Where: "deleted_at IS NULL",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "userplatformquota_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
|
||||
UserSubscriptionsColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@ -1736,6 +1784,7 @@ var (
|
||||
UserAllowedGroupsTable,
|
||||
UserAttributeDefinitionsTable,
|
||||
UserAttributeValuesTable,
|
||||
UserPlatformQuotasTable,
|
||||
UserSubscriptionsTable,
|
||||
}
|
||||
)
|
||||
@ -1869,6 +1918,10 @@ func init() {
|
||||
UserAttributeValuesTable.Annotation = &entsql.Annotation{
|
||||
Table: "user_attribute_values",
|
||||
}
|
||||
UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable
|
||||
UserPlatformQuotasTable.Annotation = &entsql.Annotation{
|
||||
Table: "user_platform_quotas",
|
||||
}
|
||||
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
|
||||
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
|
||||
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector)
|
||||
// UserAttributeValue is the predicate function for userattributevalue builders.
|
||||
type UserAttributeValue func(*sql.Selector)
|
||||
|
||||
// UserPlatformQuota is the predicate function for userplatformquota builders.
|
||||
type UserPlatformQuota func(*sql.Selector)
|
||||
|
||||
// UserSubscription is the predicate function for usersubscription builders.
|
||||
type UserSubscription func(*sql.Selector)
|
||||
|
||||
@ -39,6 +39,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
)
|
||||
@ -869,8 +870,12 @@ func init() {
|
||||
groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
|
||||
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
|
||||
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
|
||||
// groupDescModelsListConfig is the schema descriptor for models_list_config field.
|
||||
groupDescModelsListConfig := groupFields[30].Descriptor()
|
||||
// group.DefaultModelsListConfig holds the default value on creation for the models_list_config field.
|
||||
group.DefaultModelsListConfig = groupDescModelsListConfig.Default.(domain.GroupModelsListConfig)
|
||||
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
|
||||
groupDescRpmLimit := groupFields[30].Descriptor()
|
||||
groupDescRpmLimit := groupFields[31].Descriptor()
|
||||
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
|
||||
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
|
||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||
@ -1997,6 +2002,56 @@ func init() {
|
||||
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
|
||||
// userattributevalue.DefaultValue holds the default value on creation for the value field.
|
||||
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
|
||||
userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin()
|
||||
userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks()
|
||||
userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0]
|
||||
userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors()
|
||||
userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0]
|
||||
userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields()
|
||||
_ = userplatformquotaMixinFields0
|
||||
userplatformquotaFields := schema.UserPlatformQuota{}.Fields()
|
||||
_ = userplatformquotaFields
|
||||
// userplatformquotaDescCreatedAt is the schema descriptor for created_at field.
|
||||
userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor()
|
||||
// userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time)
|
||||
// userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor()
|
||||
// userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time)
|
||||
// userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
// userplatformquotaDescPlatform is the schema descriptor for platform field.
|
||||
userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor()
|
||||
// userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
|
||||
userplatformquota.PlatformValidator = func() func(string) error {
|
||||
validators := userplatformquotaDescPlatform.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
validators[2].(func(string) error),
|
||||
}
|
||||
return func(platform string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field.
|
||||
userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor()
|
||||
// userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field.
|
||||
userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64)
|
||||
// userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field.
|
||||
userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor()
|
||||
// userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field.
|
||||
userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64)
|
||||
// userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field.
|
||||
userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor()
|
||||
// userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field.
|
||||
userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64)
|
||||
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
|
||||
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
|
||||
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]
|
||||
|
||||
@ -155,6 +155,10 @@ func (Group) Fields() []ent.Field {
|
||||
Default(domain.OpenAIMessagesDispatchModelConfig{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
|
||||
field.JSON("models_list_config", domain.GroupModelsListConfig{}).
|
||||
Default(domain.GroupModelsListConfig{}).
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||
Comment("自定义 /v1/models 展示列表配置;仅影响模型列表响应,不影响调度"),
|
||||
|
||||
// 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。
|
||||
field.Int("rpm_limit").
|
||||
|
||||
@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge {
|
||||
edge.To("auth_identities", AuthIdentity.Type).
|
||||
Annotations(entsql.OnDelete(entsql.Cascade)),
|
||||
edge.To("pending_auth_sessions", PendingAuthSession.Type),
|
||||
edge.To("platform_quotas", UserPlatformQuota.Type),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
113
backend/ent/schema/user_platform_quota.go
Normal file
113
backend/ent/schema/user_platform_quota.go
Normal file
@ -0,0 +1,113 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/edge"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
)
|
||||
|
||||
// UserPlatformQuota holds the schema definition for per-user per-platform quota.
|
||||
type UserPlatformQuota struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "user_platform_quotas"},
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Mixin() []ent.Mixin {
|
||||
return []ent.Mixin{
|
||||
mixins.TimeMixin{},
|
||||
mixins.SoftDeleteMixin{},
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Int64("user_id"),
|
||||
field.String("platform").
|
||||
MaxLen(32).
|
||||
NotEmpty().
|
||||
Validate(func(s string) error {
|
||||
// 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms;
|
||||
// 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。
|
||||
switch s {
|
||||
case "anthropic", "openai", "gemini", "antigravity":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("platform %q is not allowed", s)
|
||||
}
|
||||
}),
|
||||
|
||||
// 日 / 周 / 月 USD 上限:
|
||||
// nil / not set → 无限额(完全放行)
|
||||
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
|
||||
// > 0 → USD 限额上限
|
||||
field.Float("daily_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("weekly_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("monthly_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
|
||||
// 当前窗口已用量(USD,preflight 时与 limit 比较)
|
||||
field.Float("daily_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("weekly_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("monthly_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
|
||||
// 窗口起点(NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底)
|
||||
field.Time("daily_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("weekly_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("monthly_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Edges() []ent.Edge {
|
||||
return []ent.Edge{
|
||||
edge.From("user", User.Type).
|
||||
Ref("platform_quotas").
|
||||
Field("user_id").
|
||||
Unique().
|
||||
Required(),
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// 软删除友好:只对未删记录唯一
|
||||
index.Fields("user_id", "platform").
|
||||
Unique().
|
||||
Annotations(entsql.IndexWhere("deleted_at IS NULL")),
|
||||
index.Fields("user_id"),
|
||||
}
|
||||
}
|
||||
@ -80,6 +80,8 @@ type Tx struct {
|
||||
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||
UserAttributeValue *UserAttributeValueClient
|
||||
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
|
||||
UserPlatformQuota *UserPlatformQuotaClient
|
||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||
UserSubscription *UserSubscriptionClient
|
||||
|
||||
@ -246,6 +248,7 @@ func (tx *Tx) init() {
|
||||
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
|
||||
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
|
||||
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
|
||||
tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config)
|
||||
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
|
||||
}
|
||||
|
||||
|
||||
@ -95,11 +95,13 @@ type UserEdges struct {
|
||||
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
|
||||
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
|
||||
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
|
||||
// PlatformQuotas holds the value of the platform_quotas edge.
|
||||
PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"`
|
||||
// UserAllowedGroups holds the value of the user_allowed_groups edge.
|
||||
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [13]bool
|
||||
loadedTypes [14]bool
|
||||
}
|
||||
|
||||
// APIKeysOrErr returns the APIKeys value or an error if the edge
|
||||
@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
|
||||
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
|
||||
}
|
||||
|
||||
// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) {
|
||||
if e.loadedTypes[12] {
|
||||
return e.PlatformQuotas, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "platform_quotas"}
|
||||
}
|
||||
|
||||
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
|
||||
if e.loadedTypes[12] {
|
||||
if e.loadedTypes[13] {
|
||||
return e.UserAllowedGroups, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user_allowed_groups"}
|
||||
@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
|
||||
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity.
|
||||
func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery {
|
||||
return NewUserClient(_m.config).QueryPlatformQuotas(_m)
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
|
||||
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
|
||||
|
||||
@ -85,6 +85,8 @@ const (
|
||||
EdgeAuthIdentities = "auth_identities"
|
||||
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
|
||||
EdgePendingAuthSessions = "pending_auth_sessions"
|
||||
// EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations.
|
||||
EdgePlatformQuotas = "platform_quotas"
|
||||
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
|
||||
EdgeUserAllowedGroups = "user_allowed_groups"
|
||||
// Table holds the table name of the user in the database.
|
||||
@ -171,6 +173,13 @@ const (
|
||||
PendingAuthSessionsInverseTable = "pending_auth_sessions"
|
||||
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
|
||||
PendingAuthSessionsColumn = "target_user_id"
|
||||
// PlatformQuotasTable is the table that holds the platform_quotas relation/edge.
|
||||
PlatformQuotasTable = "user_platform_quotas"
|
||||
// PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "userplatformquota" package.
|
||||
PlatformQuotasInverseTable = "user_platform_quotas"
|
||||
// PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge.
|
||||
PlatformQuotasColumn = "user_id"
|
||||
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
|
||||
UserAllowedGroupsTable = "user_allowed_groups"
|
||||
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
|
||||
@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti
|
||||
}
|
||||
}
|
||||
|
||||
// ByPlatformQuotasCount orders the results by platform_quotas count.
|
||||
func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByPlatformQuotas orders the results by platform_quotas terms.
|
||||
func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
|
||||
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step {
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
|
||||
)
|
||||
}
|
||||
func newPlatformQuotasStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(PlatformQuotasInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
|
||||
)
|
||||
}
|
||||
func newUserAllowedGroupsStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
|
||||
@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate
|
||||
})
|
||||
}
|
||||
|
||||
// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge.
|
||||
func HasPlatformQuotas() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates).
|
||||
func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := newPlatformQuotasStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
|
||||
func HasUserAllowedGroups() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea
|
||||
return _c.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _c.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_c *UserCreate) Mutation() *UserMutation {
|
||||
return _c.mutation
|
||||
@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
return _node, _spec
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -48,6 +49,7 @@ type UserQuery struct {
|
||||
withPaymentOrders *PaymentOrderQuery
|
||||
withAuthIdentities *AuthIdentityQuery
|
||||
withPendingAuthSessions *PendingAuthSessionQuery
|
||||
withPlatformQuotas *UserPlatformQuotaQuery
|
||||
withUserAllowedGroups *UserAllowedGroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas chains the current query on the "platform_quotas" edge.
|
||||
func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery {
|
||||
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, selector),
|
||||
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
|
||||
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: _q.config}).Query()
|
||||
@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery {
|
||||
withPaymentOrders: _q.withPaymentOrders.Clone(),
|
||||
withAuthIdentities: _q.withAuthIdentities.Clone(),
|
||||
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
|
||||
withPlatformQuotas: _q.withPlatformQuotas.Clone(),
|
||||
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery {
|
||||
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withPlatformQuotas = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
|
||||
@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
var (
|
||||
nodes = []*User{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [13]bool{
|
||||
loadedTypes = [14]bool{
|
||||
_q.withAPIKeys != nil,
|
||||
_q.withRedeemCodes != nil,
|
||||
_q.withSubscriptions != nil,
|
||||
@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
_q.withPaymentOrders != nil,
|
||||
_q.withAuthIdentities != nil,
|
||||
_q.withPendingAuthSessions != nil,
|
||||
_q.withPlatformQuotas != nil,
|
||||
_q.withUserAllowedGroups != nil,
|
||||
}
|
||||
)
|
||||
@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withPlatformQuotas; query != nil {
|
||||
if err := _q.loadPlatformQuotas(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} },
|
||||
func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withUserAllowedGroups; query != nil {
|
||||
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
|
||||
@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
}
|
||||
if len(query.ctx.Fields) > 0 {
|
||||
query.ctx.AppendFieldOnce(userplatformquota.FieldUserID)
|
||||
}
|
||||
query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.UserID
|
||||
node, ok := nodeids[fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda
|
||||
return _u.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdate) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU
|
||||
return _u.RemovePendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate {
|
||||
_u.mutation.ClearPlatformQuotas()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
|
||||
func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.RemovePlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
|
||||
func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{user.Label}
|
||||
@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU
|
||||
return _u.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdateOne) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us
|
||||
return _u.RemovePendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne {
|
||||
_u.mutation.ClearPlatformQuotas()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
|
||||
func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.RemovePlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
|
||||
func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserUpdate builder.
|
||||
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &User{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
|
||||
301
backend/ent/userplatformquota.go
Normal file
301
backend/ent/userplatformquota.go
Normal file
@ -0,0 +1,301 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuota is the model entity for the UserPlatformQuota schema.
|
||||
type UserPlatformQuota struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// DeletedAt holds the value of the "deleted_at" field.
|
||||
DeletedAt *time.Time `json:"deleted_at,omitempty"`
|
||||
// UserID holds the value of the "user_id" field.
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
// Platform holds the value of the "platform" field.
|
||||
Platform string `json:"platform,omitempty"`
|
||||
// DailyLimitUsd holds the value of the "daily_limit_usd" field.
|
||||
DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"`
|
||||
// WeeklyLimitUsd holds the value of the "weekly_limit_usd" field.
|
||||
WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
// MonthlyLimitUsd holds the value of the "monthly_limit_usd" field.
|
||||
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
// DailyUsageUsd holds the value of the "daily_usage_usd" field.
|
||||
DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"`
|
||||
// WeeklyUsageUsd holds the value of the "weekly_usage_usd" field.
|
||||
WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"`
|
||||
// MonthlyUsageUsd holds the value of the "monthly_usage_usd" field.
|
||||
MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"`
|
||||
// DailyWindowStart holds the value of the "daily_window_start" field.
|
||||
DailyWindowStart *time.Time `json:"daily_window_start,omitempty"`
|
||||
// WeeklyWindowStart holds the value of the "weekly_window_start" field.
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"`
|
||||
// MonthlyWindowStart holds the value of the "monthly_window_start" field.
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserPlatformQuotaQuery when eager-loading is set.
|
||||
Edges UserPlatformQuotaEdges `json:"edges"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph.
|
||||
type UserPlatformQuotaEdges struct {
|
||||
// User holds the value of the user edge.
|
||||
User *User `json:"user,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [1]bool
|
||||
}
|
||||
|
||||
// UserOrErr returns the User value or an error if the edge
|
||||
// was not loaded in eager-loading, or loaded but was not found.
|
||||
func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) {
|
||||
if e.User != nil {
|
||||
return e.User, nil
|
||||
} else if e.loadedTypes[0] {
|
||||
return nil, &NotFoundError{label: user.Label}
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user"}
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case userplatformquota.FieldID, userplatformquota.FieldUserID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case userplatformquota.FieldPlatform:
|
||||
values[i] = new(sql.NullString)
|
||||
case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the UserPlatformQuota fields.
|
||||
func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case userplatformquota.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case userplatformquota.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldUpdatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpdatedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldDeletedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DeletedAt = new(time.Time)
|
||||
*_m.DeletedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldUserID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field user_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UserID = value.Int64
|
||||
}
|
||||
case userplatformquota.FieldPlatform:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field platform", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Platform = value.String
|
||||
}
|
||||
case userplatformquota.FieldDailyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyLimitUsd = new(float64)
|
||||
*_m.DailyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldWeeklyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyLimitUsd = new(float64)
|
||||
*_m.WeeklyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldMonthlyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyLimitUsd = new(float64)
|
||||
*_m.MonthlyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldDailyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldWeeklyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldMonthlyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldDailyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyWindowStart = new(time.Time)
|
||||
*_m.DailyWindowStart = value.Time
|
||||
}
|
||||
case userplatformquota.FieldWeeklyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyWindowStart = new(time.Time)
|
||||
*_m.WeeklyWindowStart = value.Time
|
||||
}
|
||||
case userplatformquota.FieldMonthlyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyWindowStart = new(time.Time)
|
||||
*_m.MonthlyWindowStart = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// QueryUser queries the "user" edge of the UserPlatformQuota entity.
|
||||
func (_m *UserPlatformQuota) QueryUser() *UserQuery {
|
||||
return NewUserPlatformQuotaClient(_m.config).QueryUser(_m)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this UserPlatformQuota.
|
||||
// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne {
|
||||
return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: UserPlatformQuota is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *UserPlatformQuota) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("UserPlatformQuota(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("updated_at=")
|
||||
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DeletedAt; v != nil {
|
||||
builder.WriteString("deleted_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("user_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("platform=")
|
||||
builder.WriteString(_m.Platform)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DailyLimitUsd; v != nil {
|
||||
builder.WriteString("daily_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.WeeklyLimitUsd; v != nil {
|
||||
builder.WriteString("weekly_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.MonthlyLimitUsd; v != nil {
|
||||
builder.WriteString("monthly_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("daily_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("weekly_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("monthly_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DailyWindowStart; v != nil {
|
||||
builder.WriteString("daily_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.WeeklyWindowStart; v != nil {
|
||||
builder.WriteString("weekly_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.MonthlyWindowStart; v != nil {
|
||||
builder.WriteString("monthly_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota.
|
||||
type UserPlatformQuotaSlice []*UserPlatformQuota
|
||||
202
backend/ent/userplatformquota/userplatformquota.go
Normal file
202
backend/ent/userplatformquota/userplatformquota.go
Normal file
@ -0,0 +1,202 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package userplatformquota
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the userplatformquota type in the database.
|
||||
Label = "user_platform_quota"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||
FieldUpdatedAt = "updated_at"
|
||||
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
|
||||
FieldDeletedAt = "deleted_at"
|
||||
// FieldUserID holds the string denoting the user_id field in the database.
|
||||
FieldUserID = "user_id"
|
||||
// FieldPlatform holds the string denoting the platform field in the database.
|
||||
FieldPlatform = "platform"
|
||||
// FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database.
|
||||
FieldDailyLimitUsd = "daily_limit_usd"
|
||||
// FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database.
|
||||
FieldWeeklyLimitUsd = "weekly_limit_usd"
|
||||
// FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database.
|
||||
FieldMonthlyLimitUsd = "monthly_limit_usd"
|
||||
// FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database.
|
||||
FieldDailyUsageUsd = "daily_usage_usd"
|
||||
// FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database.
|
||||
FieldWeeklyUsageUsd = "weekly_usage_usd"
|
||||
// FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database.
|
||||
FieldMonthlyUsageUsd = "monthly_usage_usd"
|
||||
// FieldDailyWindowStart holds the string denoting the daily_window_start field in the database.
|
||||
FieldDailyWindowStart = "daily_window_start"
|
||||
// FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database.
|
||||
FieldWeeklyWindowStart = "weekly_window_start"
|
||||
// FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database.
|
||||
FieldMonthlyWindowStart = "monthly_window_start"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
EdgeUser = "user"
|
||||
// Table holds the table name of the userplatformquota in the database.
|
||||
Table = "user_platform_quotas"
|
||||
// UserTable is the table that holds the user relation/edge.
|
||||
UserTable = "user_platform_quotas"
|
||||
// UserInverseTable is the table name for the User entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "user" package.
|
||||
UserInverseTable = "users"
|
||||
// UserColumn is the table column denoting the user relation/edge.
|
||||
UserColumn = "user_id"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for userplatformquota fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldCreatedAt,
|
||||
FieldUpdatedAt,
|
||||
FieldDeletedAt,
|
||||
FieldUserID,
|
||||
FieldPlatform,
|
||||
FieldDailyLimitUsd,
|
||||
FieldWeeklyLimitUsd,
|
||||
FieldMonthlyLimitUsd,
|
||||
FieldDailyUsageUsd,
|
||||
FieldWeeklyUsageUsd,
|
||||
FieldMonthlyUsageUsd,
|
||||
FieldDailyWindowStart,
|
||||
FieldWeeklyWindowStart,
|
||||
FieldMonthlyWindowStart,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Note that the variables below are initialized by the runtime
|
||||
// package on the initialization of the application. Therefore,
|
||||
// it should be imported in the main as follows:
|
||||
//
|
||||
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
var (
|
||||
Hooks [1]ent.Hook
|
||||
Interceptors [1]ent.Interceptor
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
DefaultUpdatedAt func() time.Time
|
||||
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
|
||||
PlatformValidator func(string) error
|
||||
// DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field.
|
||||
DefaultDailyUsageUsd float64
|
||||
// DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field.
|
||||
DefaultWeeklyUsageUsd float64
|
||||
// DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field.
|
||||
DefaultMonthlyUsageUsd float64
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the UserPlatformQuota queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpdatedAt orders the results by the updated_at field.
|
||||
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDeletedAt orders the results by the deleted_at field.
|
||||
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserID orders the results by the user_id field.
|
||||
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUserID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPlatform orders the results by the platform field.
|
||||
func ByPlatform(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPlatform, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyLimitUsd orders the results by the daily_limit_usd field.
|
||||
func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field.
|
||||
func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field.
|
||||
func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyUsageUsd orders the results by the daily_usage_usd field.
|
||||
func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field.
|
||||
func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field.
|
||||
func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyWindowStart orders the results by the daily_window_start field.
|
||||
func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyWindowStart orders the results by the weekly_window_start field.
|
||||
func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyWindowStart orders the results by the monthly_window_start field.
|
||||
func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserField orders the results by user field.
|
||||
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
|
||||
}
|
||||
}
|
||||
func newUserStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(UserInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
}
|
||||
799
backend/ent/userplatformquota/where.go
Normal file
799
backend/ent/userplatformquota/where.go
Normal file
@ -0,0 +1,799 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package userplatformquota
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||
func UpdatedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
|
||||
func DeletedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
|
||||
func UserID(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ.
|
||||
func Platform(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ.
|
||||
func DailyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ.
|
||||
func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ.
|
||||
func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ.
|
||||
func DailyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ.
|
||||
func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ.
|
||||
func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ.
|
||||
func DailyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ.
|
||||
func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ.
|
||||
func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||
func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||
func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||
func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||
func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||
func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||
func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||
func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||
func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
|
||||
func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
|
||||
func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtIn applies the In predicate on the "deleted_at" field.
|
||||
func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...))
|
||||
}
|
||||
|
||||
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
|
||||
func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...))
|
||||
}
|
||||
|
||||
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
|
||||
func DeletedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
|
||||
func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
|
||||
func DeletedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
|
||||
func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
|
||||
func DeletedAtIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt))
|
||||
}
|
||||
|
||||
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
|
||||
func DeletedAtNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt))
|
||||
}
|
||||
|
||||
// UserIDEQ applies the EQ predicate on the "user_id" field.
|
||||
func UserIDEQ(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
|
||||
func UserIDNEQ(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDIn applies the In predicate on the "user_id" field.
|
||||
func UserIDIn(vs ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
|
||||
func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// PlatformEQ applies the EQ predicate on the "platform" field.
|
||||
func PlatformEQ(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformNEQ applies the NEQ predicate on the "platform" field.
|
||||
func PlatformNEQ(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformIn applies the In predicate on the "platform" field.
|
||||
func PlatformIn(vs ...string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...))
|
||||
}
|
||||
|
||||
// PlatformNotIn applies the NotIn predicate on the "platform" field.
|
||||
func PlatformNotIn(vs ...string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...))
|
||||
}
|
||||
|
||||
// PlatformGT applies the GT predicate on the "platform" field.
|
||||
func PlatformGT(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformGTE applies the GTE predicate on the "platform" field.
|
||||
func PlatformGTE(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformLT applies the LT predicate on the "platform" field.
|
||||
func PlatformLT(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformLTE applies the LTE predicate on the "platform" field.
|
||||
func PlatformLTE(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformContains applies the Contains predicate on the "platform" field.
|
||||
func PlatformContains(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field.
|
||||
func PlatformHasPrefix(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field.
|
||||
func PlatformHasSuffix(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformEqualFold applies the EqualFold predicate on the "platform" field.
|
||||
func PlatformEqualFold(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformContainsFold applies the ContainsFold predicate on the "platform" field.
|
||||
func PlatformContainsFold(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd))
|
||||
}
|
||||
|
||||
// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartIn applies the In predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart))
|
||||
}
|
||||
|
||||
// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart))
|
||||
}
|
||||
|
||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||
func HasUser() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
|
||||
func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
step := newUserStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.NotPredicates(p))
|
||||
}
|
||||
1513
backend/ent/userplatformquota_create.go
Normal file
1513
backend/ent/userplatformquota_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/userplatformquota_delete.go
Normal file
88
backend/ent/userplatformquota_delete.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity.
|
||||
type UserPlatformQuotaDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
|
||||
func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity.
|
||||
type UserPlatformQuotaDeleteOne struct {
|
||||
_d *UserPlatformQuotaDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
|
||||
func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
643
backend/ent/userplatformquota_query.go
Normal file
643
backend/ent/userplatformquota_query.go
Normal file
@ -0,0 +1,643 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities.
|
||||
type UserPlatformQuotaQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []userplatformquota.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.UserPlatformQuota
|
||||
withUser *UserQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the UserPlatformQuotaQuery builder.
|
||||
func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// QueryUser chains the current query on the "user" edge.
|
||||
func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// First returns the first UserPlatformQuota entity from the query.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota was found.
|
||||
func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{userplatformquota.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first UserPlatformQuota ID from the query.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota ID was found.
|
||||
func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one UserPlatformQuota entity is found.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota entities are found.
|
||||
func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{userplatformquota.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query.
|
||||
// Returns a *NotSingularError when more than one UserPlatformQuota ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
err = &NotSingularError{userplatformquota.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of UserPlatformQuotaSlice.
|
||||
func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]()
|
||||
return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of UserPlatformQuota IDs.
|
||||
func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserPlatformQuotaQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]userplatformquota.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...),
|
||||
withUser: _q.withUser.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// WithUser tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withUser = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UserPlatformQuota.Query().
|
||||
// GroupBy(userplatformquota.FieldCreatedAt).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &UserPlatformQuotaGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = userplatformquota.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UserPlatformQuota.Query().
|
||||
// Select(userplatformquota.FieldCreatedAt).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q}
|
||||
sbuild.label = userplatformquota.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations.
|
||||
func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !userplatformquota.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) {
|
||||
var (
|
||||
nodes = []*UserPlatformQuota{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [1]bool{
|
||||
_q.withUser != nil,
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*UserPlatformQuota).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &UserPlatformQuota{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
if query := _q.withUser; query != nil {
|
||||
if err := _q.loadUser(ctx, query, nodes, nil,
|
||||
func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error {
|
||||
ids := make([]int64, 0, len(nodes))
|
||||
nodeids := make(map[int64][]*UserPlatformQuota)
|
||||
for i := range nodes {
|
||||
fk := nodes[i].UserID
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
query.Where(user.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != userplatformquota.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
if _q.withUser != nil {
|
||||
_spec.Node.AddColumnOnce(userplatformquota.FieldUserID)
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(userplatformquota.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = userplatformquota.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities.
|
||||
type UserPlatformQuotaGroupBy struct {
|
||||
selector
|
||||
build *UserPlatformQuotaQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities.
|
||||
type UserPlatformQuotaSelect struct {
|
||||
*UserPlatformQuotaQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
985
backend/ent/userplatformquota_update.go
Normal file
985
backend/ent/userplatformquota_update.go
Normal file
@ -0,0 +1,985 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities.
|
||||
type UserPlatformQuotaUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
|
||||
func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedAt sets the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetDeletedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDeletedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDeletedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatform sets the "platform" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetPlatform(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePlatform sets the "platform" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetPlatform(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyLimitUsd sets the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetDailyLimitUsd()
|
||||
_u.mutation.SetDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDailyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetWeeklyLimitUsd()
|
||||
_u.mutation.SetWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearWeeklyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetMonthlyLimitUsd()
|
||||
_u.mutation.SetMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearMonthlyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyUsageUsd sets the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetDailyUsageUsd()
|
||||
_u.mutation.SetDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetWeeklyUsageUsd()
|
||||
_u.mutation.SetWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetMonthlyUsageUsd()
|
||||
_u.mutation.SetMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyWindowStart sets the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetDailyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDailyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyWindowStart sets the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetWeeklyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearWeeklyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyWindowStart sets the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetMonthlyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearMonthlyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the UserPlatformQuotaMutation object of the builder.
|
||||
func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UserPlatformQuotaUpdate) defaults() error {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
if userplatformquota.UpdateDefaultUpdatedAt == nil {
|
||||
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||
}
|
||||
v := userplatformquota.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UserPlatformQuotaUpdate) check() error {
|
||||
if v, ok := _u.mutation.Platform(); ok {
|
||||
if err := userplatformquota.PlatformValidator(v); err != nil {
|
||||
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DeletedAtCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Platform(); ok {
|
||||
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.DailyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.WeeklyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.MonthlyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DailyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.WeeklyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.MonthlyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity.
|
||||
type UserPlatformQuotaUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedAt sets the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetDeletedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDeletedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDeletedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatform sets the "platform" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetPlatform(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePlatform sets the "platform" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPlatform(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyLimitUsd sets the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetDailyLimitUsd()
|
||||
_u.mutation.SetDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDailyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetWeeklyLimitUsd()
|
||||
_u.mutation.SetWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearWeeklyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetMonthlyLimitUsd()
|
||||
_u.mutation.SetMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearMonthlyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyUsageUsd sets the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetDailyUsageUsd()
|
||||
_u.mutation.SetDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetWeeklyUsageUsd()
|
||||
_u.mutation.SetWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetMonthlyUsageUsd()
|
||||
_u.mutation.SetMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyWindowStart sets the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetDailyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDailyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyWindowStart sets the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetWeeklyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearWeeklyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyWindowStart sets the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetMonthlyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearMonthlyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the UserPlatformQuotaMutation object of the builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated UserPlatformQuota entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UserPlatformQuotaUpdateOne) defaults() error {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
if userplatformquota.UpdateDefaultUpdatedAt == nil {
|
||||
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||
}
|
||||
v := userplatformquota.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) check() error {
|
||||
if v, ok := _u.mutation.Platform(); ok {
|
||||
if err := userplatformquota.PlatformValidator(v); err != nil {
|
||||
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
|
||||
for _, f := range fields {
|
||||
if !userplatformquota.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != userplatformquota.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DeletedAtCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Platform(); ok {
|
||||
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.DailyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.WeeklyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.MonthlyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DailyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.WeeklyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.MonthlyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &UserPlatformQuota{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
@ -6,12 +6,14 @@ require (
|
||||
connectrpc.com/connect v1.19.2
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/alicebob/miniredis/v2 v2.38.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/andybalholm/brotli v1.2.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
github.com/aws/smithy-go v1.24.2
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/dgraph-io/ristretto v0.2.0
|
||||
@ -74,7 +76,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
@ -160,6 +161,7 @@ require (
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
github.com/zclconf/go-cty v1.14.4 // indirect
|
||||
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
|
||||
|
||||
@ -18,6 +18,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l
|
||||
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
|
||||
github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw=
|
||||
github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
|
||||
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
@ -362,6 +364,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H
|
||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=
|
||||
|
||||
@ -651,6 +651,12 @@ type ProxyProbeConfig struct {
|
||||
|
||||
type BillingConfig struct {
|
||||
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
|
||||
// UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口。
|
||||
// 消费点:
|
||||
// - billing_cache_service.cacheWriteWorker 异步累加
|
||||
// - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载
|
||||
// 读写两端必须共用同一 TTL,避免缓存生命周期不一致导致 quota 计数漂移。
|
||||
UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"`
|
||||
}
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
@ -688,6 +694,9 @@ type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
// 注意:这不影响流式数据传输,只控制等待响应头的时间
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// OpenAIResponseHeaderTimeout: OpenAI/Codex 上游等待响应头的超时时间(秒),0表示无超时
|
||||
// OpenAI/Codex 请求可能在上游排队较久;默认不使用通用响应头超时截断。
|
||||
OpenAIResponseHeaderTimeout int `mapstructure:"openai_response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||
@ -717,6 +726,8 @@ type GatewayConfig struct {
|
||||
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
|
||||
// NodeTLSProxy: Node.js TLS 代理配置
|
||||
NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"`
|
||||
// OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1)
|
||||
OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"`
|
||||
// ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
|
||||
ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
|
||||
|
||||
@ -815,6 +826,21 @@ type GatewayConfig struct {
|
||||
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
|
||||
}
|
||||
|
||||
// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。
|
||||
// 默认启用 HTTP/2;在部分代理不兼容时按策略回退 HTTP/1.1。
|
||||
type GatewayOpenAIHTTP2Config struct {
|
||||
// Enabled: 是否启用 OpenAI HTTP/2 优先策略
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// AllowProxyFallbackToHTTP1: HTTP/HTTPS 代理出现明确 H2 兼容错误时,临时回退 HTTP/1.1
|
||||
AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"`
|
||||
// FallbackErrorThreshold: 回退窗口内累计多少次兼容错误后触发回退
|
||||
FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"`
|
||||
// FallbackWindowSeconds: 统计兼容错误的时间窗口(秒)
|
||||
FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"`
|
||||
// FallbackTTLSeconds: 触发后回退 HTTP/1.1 的持续时间(秒)
|
||||
FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"`
|
||||
}
|
||||
|
||||
// UserMessageQueueConfig 用户消息串行队列配置
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
|
||||
type UserMessageQueueConfig struct {
|
||||
@ -1647,6 +1673,7 @@ func setDefaults() {
|
||||
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
|
||||
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
|
||||
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
|
||||
viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400)
|
||||
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
@ -1847,6 +1874,7 @@ func setDefaults() {
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.openai_response_header_timeout", 0)
|
||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||
@ -1902,6 +1930,12 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
|
||||
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
|
||||
// OpenAI HTTP upstream protocol strategy
|
||||
viper.SetDefault("gateway.openai_http2.enabled", true)
|
||||
viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true)
|
||||
viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2)
|
||||
viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60)
|
||||
viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600)
|
||||
viper.SetDefault("gateway.image_concurrency.enabled", false)
|
||||
viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
|
||||
viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
|
||||
@ -2523,6 +2557,12 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.ResponseHeaderTimeout < 0 {
|
||||
return fmt.Errorf("gateway.response_header_timeout must be non-negative")
|
||||
}
|
||||
if c.Gateway.OpenAIResponseHeaderTimeout < 0 {
|
||||
return fmt.Errorf("gateway.openai_response_header_timeout must be non-negative")
|
||||
}
|
||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||
switch c.Gateway.ConnectionPoolIsolation {
|
||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||
@ -2697,6 +2737,15 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
|
||||
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 {
|
||||
return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative")
|
||||
}
|
||||
if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 {
|
||||
return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 {
|
||||
return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
|
||||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
|
||||
|
||||
@ -163,6 +163,41 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultOpenAIHTTP2Enabled(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
require.NoError(t, err)
|
||||
require.True(t, cfg.Gateway.OpenAIHTTP2.Enabled)
|
||||
require.True(t, cfg.Gateway.OpenAIHTTP2.AllowProxyFallbackToHTTP1)
|
||||
}
|
||||
|
||||
func TestLoadOpenAIHTTP2DisabledFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_OPENAI_HTTP2_ENABLED", "false")
|
||||
|
||||
cfg, err := Load()
|
||||
require.NoError(t, err)
|
||||
require.False(t, cfg.Gateway.OpenAIHTTP2.Enabled)
|
||||
}
|
||||
|
||||
func TestLoadDefaultOpenAIResponseHeaderTimeoutUnlimited(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, cfg.Gateway.OpenAIResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
func TestLoadOpenAIResponseHeaderTimeoutFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_OPENAI_RESPONSE_HEADER_TIMEOUT", "1800")
|
||||
|
||||
cfg, err := Load()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1800, cfg.Gateway.OpenAIResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
|
||||
@ -1220,6 +1255,16 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
|
||||
wantErr: "gateway.max_body_size",
|
||||
},
|
||||
{
|
||||
name: "gateway response header timeout",
|
||||
mutate: func(c *Config) { c.Gateway.ResponseHeaderTimeout = -1 },
|
||||
wantErr: "gateway.response_header_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway openai response header timeout",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIResponseHeaderTimeout = -1 },
|
||||
wantErr: "gateway.openai_response_header_timeout",
|
||||
},
|
||||
{
|
||||
name: "gateway max idle conns",
|
||||
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
|
||||
@ -1275,6 +1320,21 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
|
||||
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
|
||||
},
|
||||
{
|
||||
name: "gateway openai http2 fallback threshold",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackErrorThreshold = -1 },
|
||||
wantErr: "gateway.openai_http2.fallback_error_threshold",
|
||||
},
|
||||
{
|
||||
name: "gateway openai http2 fallback window",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackWindowSeconds = -1 },
|
||||
wantErr: "gateway.openai_http2.fallback_window_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway openai http2 fallback ttl",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIHTTP2.FallbackTTLSeconds = -1 },
|
||||
wantErr: "gateway.openai_http2.fallback_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway stream data interval range",
|
||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
|
||||
|
||||
7
backend/internal/domain/models_list_config.go
Normal file
7
backend/internal/domain/models_list_config.go
Normal file
@ -0,0 +1,7 @@
|
||||
package domain
|
||||
|
||||
// GroupModelsListConfig controls the optional custom /v1/models response list.
|
||||
type GroupModelsListConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
}
|
||||
@ -982,6 +982,100 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
}
|
||||
|
||||
// ApplyOAuthCredentialsRequest is the payload for persisting re-authorized OAuth credentials.
|
||||
type ApplyOAuthCredentialsRequest struct {
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
// ApplyOAuthCredentials 将"重新授权"得到的新凭据原子落库。
|
||||
// POST /api/v1/admin/accounts/:id/apply-oauth-credentials
|
||||
//
|
||||
// 与通用 PUT /:id (Update) 接口的关键区别:
|
||||
// - 仅接收 type / credentials / extra 三个字段(不接受 concurrency / rpm / quota_* 等可能误传的字段)
|
||||
// - Extra 走 UpdateAccountExtra(JSONB key 级合并),**绝不**全量覆盖;
|
||||
// 避免 base_rpm / window_cost_limit / max_sessions / quota_* / privacy_mode
|
||||
// 等持久化配置在重新授权后丢失
|
||||
// - 内置 ClearError + InvalidateToken,避免前端额外两次调用,
|
||||
// 并修复旧路径未失效 token 缓存导致重新授权后立即 401 的隐性 bug
|
||||
//
|
||||
// 与 /refresh 的区别:/refresh 用现有 refresh_token 换 access_token(无用户交互),
|
||||
// 本接口承接前端完成完整 OAuth 流程后的落库步骤。
|
||||
func (h *AccountHandler) ApplyOAuthCredentials(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ApplyOAuthCredentialsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 预检查账号存在 + OAuth 类型(与 Refresh handler 语义一致,提供更友好的错误信息)。
|
||||
existing, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
if !existing.IsOAuth() {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("NOT_OAUTH", "cannot apply oauth credentials to non-OAuth account"))
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 增量合并 Extra(JSONB key 级 merge,绝不覆盖 base_rpm / window_cost_limit /
|
||||
// max_sessions / quota_* / privacy_mode 等持久化键)。
|
||||
// best-effort:失败仅记日志;下方 ClearAccountError 会从 DB 重新读取最新 account,
|
||||
// 因此响应里的 extra 始终以 DB 为准——这里不需要手动维护内存快照。
|
||||
if len(req.Extra) > 0 {
|
||||
if extraErr := h.adminService.UpdateAccountExtra(ctx, accountID, req.Extra); extraErr != nil {
|
||||
extraKeys := make([]string, 0, len(req.Extra))
|
||||
for k := range req.Extra {
|
||||
extraKeys = append(extraKeys, k)
|
||||
}
|
||||
slog.Error("apply_oauth_credentials.update_extra_failed",
|
||||
"account_id", accountID,
|
||||
"extra_keys", extraKeys,
|
||||
"err", extraErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if cleared, clearErr := h.adminService.ClearAccountError(ctx, accountID); clearErr != nil {
|
||||
slog.Warn("apply_oauth_credentials.clear_error_failed",
|
||||
"account_id", accountID,
|
||||
"err", clearErr,
|
||||
)
|
||||
} else if cleared != nil {
|
||||
updatedAccount = cleared
|
||||
}
|
||||
|
||||
if h.tokenCacheInvalidator != nil && updatedAccount.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(ctx, updatedAccount); invalidateErr != nil {
|
||||
slog.Warn("apply_oauth_credentials.invalidate_token_failed",
|
||||
"account_id", accountID,
|
||||
"err", invalidateErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(ctx, updatedAccount))
|
||||
}
|
||||
|
||||
// GetStats handles getting account statistics
|
||||
// GET /api/v1/admin/accounts/:id/stats
|
||||
func (h *AccountHandler) GetStats(c *gin.Context) {
|
||||
|
||||
52
backend/internal/handler/admin/account_handler_list_test.go
Normal file
52
backend/internal/handler/admin/account_handler_list_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAccountListRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts", handler.List)
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestAccountHandlerListIncludesCreatedAt(t *testing.T) {
|
||||
router, adminSvc := setupAccountListRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts?page=1&page_size=20&sort_by=created_at&sort_order=desc", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "created_at", adminSvc.lastListAccounts.sortBy)
|
||||
|
||||
var payload struct {
|
||||
Data struct {
|
||||
Items []struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
} `json:"items"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
|
||||
require.Len(t, payload.Data.Items, 1)
|
||||
|
||||
createdAt := payload.Data.Items[0].CreatedAt
|
||||
require.NotEmpty(t, createdAt)
|
||||
require.True(t, strings.HasSuffix(createdAt, "Z"), "created_at should be serialized as UTC")
|
||||
parsed, err := time.Parse(time.RFC3339Nano, createdAt)
|
||||
require.NoError(t, err)
|
||||
_, offset := parsed.Zone()
|
||||
require.Equal(t, 0, offset)
|
||||
}
|
||||
@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
userHandler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
@ -33,6 +33,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
|
||||
router.GET("/api/v1/admin/groups", groupHandler.List)
|
||||
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
|
||||
router.GET("/api/v1/admin/groups/:id/models-list-candidates", groupHandler.GetModelsListCandidates)
|
||||
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
|
||||
router.POST("/api/v1/admin/groups", groupHandler.Create)
|
||||
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
|
||||
@ -177,6 +178,12 @@ func TestGroupHandlerEndpoints(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/0/models-list-candidates?platform=openai", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "gpt-5.5")
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
|
||||
|
||||
@ -24,6 +24,7 @@ type stubAdminService struct {
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
getUserErr error
|
||||
createAccountErr error
|
||||
updateAccountErr error
|
||||
bulkUpdateAccountErr error
|
||||
@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
|
||||
if s.getUserErr != nil {
|
||||
return nil, s.getUserErr
|
||||
}
|
||||
for i := range s.users {
|
||||
if s.users[i].ID == id {
|
||||
return &s.users[i], nil
|
||||
@ -261,6 +265,13 @@ func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Gro
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupModelsListCandidates(ctx context.Context, id int64, platform string) ([]string, error) {
|
||||
if platform == service.PlatformOpenAI {
|
||||
return []string{"gpt-5.5", "gpt-5.4"}, nil
|
||||
}
|
||||
return []string{"claude-sonnet-4-6"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
@ -345,6 +356,10 @@ func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *s
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccountExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -34,6 +34,7 @@ type contentModerationConfigRequest struct {
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
Thresholds *map[string]float64 `json:"thresholds"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
@ -94,6 +95,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||
AllGroups: req.AllGroups,
|
||||
GroupIDs: req.GroupIDs,
|
||||
RecordNonHits: req.RecordNonHits,
|
||||
Thresholds: req.Thresholds,
|
||||
WorkerCount: req.WorkerCount,
|
||||
QueueSize: req.QueueSize,
|
||||
BlockStatus: req.BlockStatus,
|
||||
|
||||
@ -113,6 +113,7 @@ type CreateGroupRequest struct {
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
ModelsListConfig service.GroupModelsListConfig `json:"models_list_config"`
|
||||
// 分组 RPM 上限(0 = 不限制)
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
@ -153,6 +154,7 @@ type UpdateGroupRequest struct {
|
||||
RequirePrivacySet *bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
ModelsListConfig *service.GroupModelsListConfig `json:"models_list_config"`
|
||||
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
|
||||
RPMLimit *int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
@ -238,6 +240,28 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// GetModelsListCandidates handles getting candidate model IDs for custom /v1/models list.
|
||||
// GET /api/v1/admin/groups/:id/models-list-candidates
|
||||
func (h *GroupHandler) GetModelsListCandidates(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil || groupID < 0 {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.adminService.GetGroupModelsListCandidates(
|
||||
c.Request.Context(),
|
||||
groupID,
|
||||
c.Query("platform"),
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
// POST /api/v1/admin/groups
|
||||
func (h *GroupHandler) Create(c *gin.Context) {
|
||||
@ -275,6 +299,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
ModelsListConfig: req.ModelsListConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
@ -330,6 +355,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
ModelsListConfig: req.ModelsListConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
|
||||
@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
// Default platform quotas(JSON map)
|
||||
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
|
||||
slog.Error("default_platform_quotas_get_failed", "error", err)
|
||||
} else {
|
||||
payload.DefaultPlatformQuotas = platformQuotas
|
||||
}
|
||||
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
@ -637,6 +644,18 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
|
||||
// 系统全局 platform quota 默认值(整体替换语义:nil = 不修改,non-nil = 整体覆盖)。
|
||||
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"`
|
||||
|
||||
// auth-source 层 platform quota 覆盖(override 语义:nil = 不修改,non-nil = 整体覆盖该 source 的 quota 配置)。
|
||||
AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"`
|
||||
AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"`
|
||||
AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"`
|
||||
AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"`
|
||||
AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"`
|
||||
AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"`
|
||||
AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
// 系统全局 platform quota 默认值(整体替换语义)
|
||||
DefaultPlatformQuotas: req.DefaultPlatformQuotas,
|
||||
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}(),
|
||||
}
|
||||
|
||||
// req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值);
|
||||
// non-nil(含 empty map)表示整体覆盖:empty map = 清空该 source 的所有 quota 配置。
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
Email: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
|
||||
@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas),
|
||||
},
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
|
||||
@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas),
|
||||
},
|
||||
OIDC: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
|
||||
@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas),
|
||||
},
|
||||
WeChat: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
|
||||
@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas),
|
||||
},
|
||||
GitHub: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
|
||||
@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas),
|
||||
},
|
||||
Google: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
|
||||
@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas),
|
||||
},
|
||||
DingTalk: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
|
||||
@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
// Default platform quotas(JSON map)—— 与 GetSettings 一致,避免保存后响应缺失该字段
|
||||
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
|
||||
slog.Error("default_platform_quotas_get_failed", "error", err)
|
||||
} else {
|
||||
payload.DefaultPlatformQuotas = platformQuotas
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.RiskControlEnabled != after.RiskControlEnabled {
|
||||
changed = append(changed, "risk_control_enabled")
|
||||
}
|
||||
// Default platform quotas(JSON map,整体比较)
|
||||
if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) {
|
||||
changed = append(changed, service.SettingKeyDefaultPlatformQuotas)
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
|
||||
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
|
||||
}
|
||||
// Platform quotas diff:整体替换语义,发单个 JSON key。
|
||||
if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) {
|
||||
changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name))
|
||||
}
|
||||
}
|
||||
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
|
||||
changed = append(changed, "force_email_on_third_party_signup")
|
||||
@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting,
|
||||
return result
|
||||
}
|
||||
|
||||
// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义:
|
||||
// nil = 请求未包含该字段(保留 fallback),non-nil(含 empty map)= 整体覆盖。
|
||||
// 注意:JSON null 与字段省略等价——两者均反序列化为 nil map,因此都保留旧值;
|
||||
// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。
|
||||
func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
|
||||
data := make(map[string]any)
|
||||
raw, err := json.Marshal(settings)
|
||||
@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
|
||||
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
|
||||
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
|
||||
data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas
|
||||
data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas
|
||||
data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas
|
||||
data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas
|
||||
data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas
|
||||
data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas
|
||||
data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas
|
||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||
|
||||
return data
|
||||
@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo)
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
|
||||
// equalNullableFloat compares two *float64 values treating nil as a distinct case.
|
||||
func equalNullableFloat(a, b *float64) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting.
|
||||
func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
switch win {
|
||||
case "daily":
|
||||
return s.DailyLimitUSD
|
||||
case "weekly":
|
||||
return s.WeeklyLimitUSD
|
||||
case "monthly":
|
||||
return s.MonthlyLimitUSD
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots.
|
||||
func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool {
|
||||
for _, platform := range service.AllowedQuotaPlatforms {
|
||||
b := before[platform]
|
||||
a := after[platform]
|
||||
if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) {
|
||||
return false
|
||||
}
|
||||
if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) {
|
||||
return false
|
||||
}
|
||||
if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@ -0,0 +1,188 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
after := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
},
|
||||
}
|
||||
|
||||
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
|
||||
found := false
|
||||
for _, key := range changed {
|
||||
if key == service.SettingKeyDefaultPlatformQuotas {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected change detection for default platform quotas, got %v", changed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffSettings_NoChangeWhenEqual(t *testing.T) {
|
||||
five := 5.0
|
||||
before := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
after := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
|
||||
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
|
||||
for _, key := range changed {
|
||||
if key == service.SettingKeyDefaultPlatformQuotas {
|
||||
t.Error("equal values should not be detected as changed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualNullableFloat(t *testing.T) {
|
||||
five := 5.0
|
||||
five2 := 5.0
|
||||
ten := 10.0
|
||||
cases := []struct {
|
||||
a, b *float64
|
||||
want bool
|
||||
}{
|
||||
{nil, nil, true},
|
||||
{&five, nil, false},
|
||||
{nil, &five, false},
|
||||
{&five, &five2, true},
|
||||
{&five, &ten, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := equalNullableFloat(c.a, c.b); got != c.want {
|
||||
t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
}
|
||||
after := map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
}
|
||||
if equalPlatformQuotaSettings(before, after) {
|
||||
t.Error("expected unequal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := &service.AuthSourceDefaultSettings{
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
},
|
||||
}
|
||||
after := &service.AuthSourceDefaultSettings{
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed := appendAuthSourceDefaultChanges([]string{}, before, after)
|
||||
// 改动 B5:整体替换语义,审计 log 发单个 JSON key,而非展开 84 个扁平 key。
|
||||
key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo")
|
||||
found := false
|
||||
for _, k := range changed {
|
||||
if k == key {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected %q in changed, got %v", key, changed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复:
|
||||
// PUT 发 auth_source_default_email_platform_quotas,GET 能读回相同值(端到端往返)。
|
||||
func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &settingHandlerRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
// PUT:发 email platform quota(openai monthly=20)
|
||||
putBody := map[string]any{
|
||||
"auth_source_default_email_platform_quotas": map[string]any{
|
||||
"openai": map[string]any{
|
||||
"monthly": 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
rawBody, err := json.Marshal(putBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
handler.UpdateSettings(c)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 验证 DB 中写入了 JSON key
|
||||
jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email")
|
||||
require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB")
|
||||
|
||||
// GET:验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20
|
||||
rec2 := httptest.NewRecorder()
|
||||
c2, _ := gin.CreateTestContext(rec2)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
|
||||
handler.GetSettings(c2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any)
|
||||
require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map")
|
||||
openaiPQ, ok := emailPQ["openai"].(map[string]any)
|
||||
require.True(t, ok, "expected openai entry in email platform quotas")
|
||||
monthly, ok := openaiPQ["monthly"].(float64)
|
||||
require.True(t, ok, "expected monthly to be float64")
|
||||
require.Equal(t, float64(20), monthly, "expected openai monthly=20")
|
||||
}
|
||||
@ -2,10 +2,15 @@ package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@ -20,15 +25,24 @@ type UserWithConcurrency struct {
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
adminService service.AdminService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view
|
||||
billingCache service.BillingCache // T17/T18 缓存失效(PUT/POST 路径)
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
|
||||
func NewUserHandler(
|
||||
adminService service.AdminService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
|
||||
billingCache service.BillingCache,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
concurrencyService: concurrencyService,
|
||||
adminService: adminService,
|
||||
concurrencyService: concurrencyService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
billingCache: billingCache,
|
||||
}
|
||||
}
|
||||
|
||||
@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, gin.H{"affected": affected})
|
||||
}
|
||||
|
||||
// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas
|
||||
// admin 视角:D14 lazy 归零 + 暴露 *_window_start 调试字段
|
||||
func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
userID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid user id")
|
||||
return
|
||||
}
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Success(c, map[string]any{"platform_quotas": []any{}})
|
||||
return
|
||||
}
|
||||
// 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。
|
||||
if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas.
|
||||
type UpdateUserPlatformQuotasRequest struct {
|
||||
Quotas []PlatformQuotaInput `json:"quotas" binding:"required"`
|
||||
}
|
||||
|
||||
// PlatformQuotaInput 单平台限额输入;limit 字段为 nil 表示不限制。
|
||||
type PlatformQuotaInput struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。
|
||||
|
||||
// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas
|
||||
// 全量替换该用户所有平台限额。
|
||||
func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) {
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Error(c, 503, "platform quota service not available")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserPlatformQuotasRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Quotas) > 4 {
|
||||
response.BadRequest(c, "quotas length must be <= 4")
|
||||
return
|
||||
}
|
||||
seen := make(map[string]struct{}, len(req.Quotas))
|
||||
for _, q := range req.Quotas {
|
||||
if !service.IsAllowedQuotaPlatform(q.Platform) {
|
||||
response.BadRequest(c, "invalid platform: "+q.Platform)
|
||||
return
|
||||
}
|
||||
if _, dup := seen[q.Platform]; dup {
|
||||
response.BadRequest(c, "duplicate platform: "+q.Platform)
|
||||
return
|
||||
}
|
||||
seen[q.Platform] = struct{}{}
|
||||
// daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义:
|
||||
// nil / not set → 无限额(完全放行)
|
||||
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
|
||||
// > 0 → USD 限额上限
|
||||
// 拦截 NaN / ±Inf:客户端可发送超大数(如 1e308 × 2)使 JSON 反序列化得到 +Inf,
|
||||
// 进入 DB 后 cache check 中 usage >= limit 永不成立,limit 等同失效。
|
||||
for _, f := range []struct {
|
||||
name string
|
||||
val *float64
|
||||
}{
|
||||
{"daily_limit_usd", q.DailyLimitUSD},
|
||||
{"weekly_limit_usd", q.WeeklyLimitUSD},
|
||||
{"monthly_limit_usd", q.MonthlyLimitUSD},
|
||||
} {
|
||||
if f.val == nil {
|
||||
continue
|
||||
}
|
||||
v := *f.val
|
||||
if v < 0 {
|
||||
response.BadRequest(c, f.name+" must be >= 0")
|
||||
return
|
||||
}
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
response.BadRequest(c, f.name+" must be a finite number")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas))
|
||||
for _, q := range req.Quotas {
|
||||
records = append(records, service.UserPlatformQuotaRecord{
|
||||
UserID: userID,
|
||||
Platform: q.Platform,
|
||||
DailyLimitUSD: q.DailyLimitUSD,
|
||||
WeeklyLimitUSD: q.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: q.MonthlyLimitUSD,
|
||||
})
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
// 校验用户是否存在,避免 FK 违反导致 500;用户不存在时返回 404。
|
||||
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。
|
||||
// ListByUser 失败不阻断主操作(best-effort),仅记录降级 warn。
|
||||
beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if beforeErr != nil {
|
||||
slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr)
|
||||
}
|
||||
if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords))
|
||||
for _, r := range beforeRecords {
|
||||
beforeByPlatform[r.Platform] = r
|
||||
}
|
||||
afterPlatforms := make(map[string]struct{}, len(records))
|
||||
for _, r := range records {
|
||||
afterPlatforms[r.Platform] = struct{}{}
|
||||
}
|
||||
changes := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
entry := map[string]any{
|
||||
"platform": r.Platform,
|
||||
"daily_limit_usd": r.DailyLimitUSD,
|
||||
"weekly_limit_usd": r.WeeklyLimitUSD,
|
||||
"monthly_limit_usd": r.MonthlyLimitUSD,
|
||||
}
|
||||
if prev, ok := beforeByPlatform[r.Platform]; ok {
|
||||
entry["before_daily_limit_usd"] = prev.DailyLimitUSD
|
||||
entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD
|
||||
entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD
|
||||
}
|
||||
changes = append(changes, entry)
|
||||
}
|
||||
// 补 removed 条目:before 存在但 after 缺失 = 该平台被软删除。
|
||||
// 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。
|
||||
for _, prev := range beforeRecords {
|
||||
if _, kept := afterPlatforms[prev.Platform]; kept {
|
||||
continue
|
||||
}
|
||||
changes = append(changes, map[string]any{
|
||||
"platform": prev.Platform,
|
||||
"removed": true,
|
||||
"before_daily_limit_usd": prev.DailyLimitUSD,
|
||||
"before_weekly_limit_usd": prev.WeeklyLimitUSD,
|
||||
"before_monthly_limit_usd": prev.MonthlyLimitUSD,
|
||||
})
|
||||
}
|
||||
// before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段;
|
||||
// false 时所有 entry 都会缺失 before_*_limit_usd,仅有 after 视图。
|
||||
slog.Info("admin.quota_updated",
|
||||
"actor_admin_id", getAdminIDFromContext(c),
|
||||
"target_user_id", userID,
|
||||
"platform_count", len(records),
|
||||
"before_snapshot_available", beforeErr == nil,
|
||||
"changes", changes)
|
||||
|
||||
// 失效 cache:对全部允许的 platform 统一 invalidate。
|
||||
// Trade-off:精确失效(仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser,
|
||||
// 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素,
|
||||
// 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。
|
||||
if h.billingCache != nil {
|
||||
for _, p := range service.AllowedQuotaPlatforms {
|
||||
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil {
|
||||
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回最新状态
|
||||
now := time.Now().UTC()
|
||||
records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]map[string]any, 0, len(records2))
|
||||
for i := range records2 {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset.
|
||||
type ResetUserPlatformQuotaWindowRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Window string `json:"window" binding:"required"`
|
||||
}
|
||||
|
||||
var allowedWindowsForQuotaReset = map[string]struct{}{
|
||||
"daily": {},
|
||||
"weekly": {},
|
||||
"monthly": {},
|
||||
}
|
||||
|
||||
// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset
|
||||
// 立即归零指定 (platform, window) 的用量并更新 window_start。
|
||||
func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) {
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Error(c, 503, "platform quota service not available")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ResetUserPlatformQuotaWindowRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsAllowedQuotaPlatform(req.Platform) {
|
||||
response.BadRequest(c, "invalid platform: "+req.Platform)
|
||||
return
|
||||
}
|
||||
if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok {
|
||||
response.BadRequest(c, "invalid window: "+req.Window)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
// 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。
|
||||
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil {
|
||||
if errors.Is(err, service.ErrUserPlatformQuotaNotFound) {
|
||||
response.NotFound(c, "user platform quota not found")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("admin.quota_window_reset",
|
||||
"actor_admin_id", getAdminIDFromContext(c),
|
||||
"target_user_id", userID,
|
||||
"platform", req.Platform,
|
||||
"window", req.Window)
|
||||
|
||||
if h.billingCache != nil {
|
||||
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil {
|
||||
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
|
||||
UpdatedAt: lastLoginAt,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(adminSvc, nil)
|
||||
handler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
|
||||
UpdatedAt: lastLoginAt,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(adminSvc, nil)
|
||||
handler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
301
backend/internal/handler/admin/user_platform_quota_admin_test.go
Normal file
301
backend/internal/handler/admin/user_platform_quota_admin_test.go
Normal file
@ -0,0 +1,301 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository,捕获 UpsertForUser 调用。
|
||||
type upsertCapturingQuotaRepo struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
listRecords []service.UserPlatformQuotaRecord
|
||||
listErr error
|
||||
upsertCalls []upsertCall
|
||||
upsertErr error
|
||||
resetCalls []resetCall
|
||||
resetErr error
|
||||
}
|
||||
|
||||
type upsertCall struct {
|
||||
userID int64
|
||||
records []service.UserPlatformQuotaRecord
|
||||
}
|
||||
type resetCall struct {
|
||||
userID int64
|
||||
platform string
|
||||
window string
|
||||
newStart time.Time
|
||||
}
|
||||
|
||||
func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return r.listRecords, r.listErr
|
||||
}
|
||||
func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
|
||||
cloned := make([]service.UserPlatformQuotaRecord, len(records))
|
||||
copy(cloned, records)
|
||||
r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned})
|
||||
return r.upsertErr
|
||||
}
|
||||
func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart})
|
||||
return r.resetErr
|
||||
}
|
||||
|
||||
// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。
|
||||
type billingCacheStub struct {
|
||||
service.BillingCache
|
||||
deleteCalls []deleteCall
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type deleteCall struct {
|
||||
userID int64
|
||||
platform string
|
||||
}
|
||||
|
||||
func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error {
|
||||
b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform})
|
||||
return b.deleteErr
|
||||
}
|
||||
|
||||
func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler {
|
||||
return &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: newStubAdminService(),
|
||||
}
|
||||
}
|
||||
|
||||
func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
c.Params = []gin.Param{{Key: "id", Value: "42"}}
|
||||
return c, w
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_Success(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0},
|
||||
{"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(repo.upsertCalls) != 1 {
|
||||
t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls))
|
||||
}
|
||||
if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 {
|
||||
t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0])
|
||||
}
|
||||
// 缓存失效:请求中 2 个 platform + 软删除的 2 个 platform(gemini, antigravity)= 4 次
|
||||
if len(cache.deleteCalls) != 4 {
|
||||
t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic","daily_limit_usd":1},
|
||||
{"platform":"anthropic","daily_limit_usd":2}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{
|
||||
listRecords: []service.UserPlatformQuotaRecord{
|
||||
{UserID: 42, Platform: "anthropic"},
|
||||
},
|
||||
}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if !strings.Contains(w.Body.String(), `"platform_quotas"`) {
|
||||
t.Errorf("response should contain platform_quotas array: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ───────── T4: Reset 测试 ─────────
|
||||
|
||||
func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
c.Params = []gin.Param{{Key: "id", Value: "42"}}
|
||||
return c, w
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_Success(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
body := `{"platform":"anthropic","window":"daily"}`
|
||||
c, w := postReq(t, body)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(repo.resetCalls) != 1 {
|
||||
t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls))
|
||||
}
|
||||
if repo.resetCalls[0].userID != 42 ||
|
||||
repo.resetCalls[0].platform != "anthropic" ||
|
||||
repo.resetCalls[0].window != "daily" {
|
||||
t.Errorf("unexpected reset call: %+v", repo.resetCalls[0])
|
||||
}
|
||||
if len(cache.deleteCalls) != 1 ||
|
||||
cache.deleteCalls[0].userID != 42 ||
|
||||
cache.deleteCalls[0].platform != "anthropic" {
|
||||
t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"unknown","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) {
|
||||
// handler 检查 service.ErrUserPlatformQuotaNotFound(由 adapter 包装而来)
|
||||
repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound}
|
||||
h := buildTestHandler(repo, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code < 500 {
|
||||
t.Errorf("expected 5xx, got %d", w.Code)
|
||||
}
|
||||
// 返回 JSON 错误响应
|
||||
var body2 map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil {
|
||||
t.Errorf("expected JSON error body, got: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = service.ErrUserNotFound
|
||||
h := &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: adminSvc,
|
||||
}
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = service.ErrUserNotFound
|
||||
h := &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: adminSvc,
|
||||
}
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,124 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type fakeQuotaRepoForAdmin struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
records []service.UserPlatformQuotaRecord
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return f.records, f.err
|
||||
}
|
||||
|
||||
func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context {
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request = req
|
||||
return c
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) {
|
||||
start := time.Now().Add(-1 * time.Hour)
|
||||
repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{
|
||||
UserID: 99, Platform: "anthropic",
|
||||
DailyUsageUSD: 1.0, DailyWindowStart: &start,
|
||||
}}}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "99"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), `"daily_window_start"`) {
|
||||
t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "abc"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code < 400 || w.Code >= 500 {
|
||||
t.Errorf("invalid id should yield 4xx, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) {
|
||||
repo := &fakeQuotaRepoForAdmin{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "99"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("empty list should be 200, got %d", w.Code)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
}
|
||||
data, ok := body["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing data object: %v", body)
|
||||
}
|
||||
quotas, ok := data["platform_quotas"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("data.platform_quotas missing or wrong type: %v", data)
|
||||
}
|
||||
if len(quotas) != 0 {
|
||||
t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "1"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("nil repo should return 200 empty, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404
|
||||
// (与 PUT / POST reset 端点行为一致;review fix:原实现返回空数组会让 admin 界面误判用户存在)
|
||||
func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||||
repo := &fakeQuotaRepoForAdmin{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "999"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
affiliateService,
|
||||
nil,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
|
||||
@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := &AuthHandler{authService: authService}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
|
||||
@ -147,6 +147,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
MCPXMLInject: g.MCPXMLInject,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||
ModelsListConfig: g.ModelsListConfig,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
AccountCount: g.AccountCount,
|
||||
ActiveAccountCount: g.ActiveAccountCount,
|
||||
|
||||
@ -3,6 +3,8 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// CustomMenuItem represents a user-configured custom menu entry.
|
||||
@ -246,6 +248,9 @@ type SystemSettings struct {
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
|
||||
// 系统全局默认平台配额(key = platform,nil/缺省 = 不限制)
|
||||
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
|
||||
@ -138,6 +138,7 @@ type AdminGroup struct {
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
ModelsListConfig domain.GroupModelsListConfig `json:"models_list_config"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointEmbeddings = "/v1/embeddings"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointImagesGenerations = "/v1/images/generations"
|
||||
EndpointImagesEdits = "/v1/images/edits"
|
||||
@ -42,6 +43,8 @@ const (
|
||||
func NormalizeInboundEndpoint(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
switch {
|
||||
case strings.Contains(path, EndpointEmbeddings):
|
||||
return EndpointEmbeddings
|
||||
case strings.Contains(path, EndpointChatCompletions):
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
@ -75,7 +78,7 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
||||
if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
||||
return inbound
|
||||
}
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
|
||||
@ -24,6 +24,7 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
// Direct canonical paths.
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/embeddings", EndpointEmbeddings},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1/images/generations", EndpointImagesGenerations},
|
||||
{"/v1/images/edits", EndpointImagesEdits},
|
||||
@ -77,6 +78,7 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai embeddings", EndpointEmbeddings, "/v1/embeddings", service.PlatformOpenAI, EndpointEmbeddings},
|
||||
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
|
||||
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -253,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -533,10 +534,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
@ -825,6 +828,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Beta policy block: return 400 immediately, no failover
|
||||
var betaBlockedErr *service.BetaBlockedError
|
||||
if errors.As(err, &betaBlockedErr) {
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message)
|
||||
return
|
||||
}
|
||||
@ -855,7 +859,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil {
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
@ -960,10 +964,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
@ -1015,22 +1021,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
|
||||
// Get available models from account configurations for the selected group platform.
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
|
||||
if apiKey != nil && apiKey.Group != nil && apiKey.Group.CustomModelsListEnabled() {
|
||||
availableModels = filterModelsByCustomList(availableModels, defaultModelIDsForPlatform(platform), apiKey.Group.ModelsListConfig.Models)
|
||||
writeCustomModelsList(c, platform, availableModels)
|
||||
return
|
||||
}
|
||||
|
||||
if len(availableModels) > 0 {
|
||||
// Build model list from whitelist
|
||||
models := make([]claude.Model, 0, len(availableModels))
|
||||
for _, modelID := range availableModels {
|
||||
models = append(models, claude.Model{
|
||||
ID: modelID,
|
||||
Type: "model",
|
||||
DisplayName: modelID,
|
||||
CreatedAt: "2024-01-01T00:00:00Z",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
writeModelsList(c, availableModels)
|
||||
return
|
||||
}
|
||||
|
||||
@ -1057,6 +1055,134 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func writeModelsList(c *gin.Context, modelIDs []string) {
|
||||
models := make([]claude.Model, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
models = append(models, claude.Model{
|
||||
ID: modelID,
|
||||
Type: "model",
|
||||
DisplayName: modelID,
|
||||
CreatedAt: "2024-01-01T00:00:00Z",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func writeCustomModelsList(c *gin.Context, platform string, modelIDs []string) {
|
||||
if platform == service.PlatformOpenAI {
|
||||
writeOpenAIModelsList(c, modelIDs)
|
||||
return
|
||||
}
|
||||
writeModelsList(c, modelIDs)
|
||||
}
|
||||
|
||||
func writeOpenAIModelsList(c *gin.Context, modelIDs []string) {
|
||||
defaultsByID := make(map[string]openai.Model, len(openai.DefaultModels))
|
||||
for _, model := range openai.DefaultModels {
|
||||
defaultsByID[model.ID] = model
|
||||
}
|
||||
|
||||
models := make([]openai.Model, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
if model, ok := defaultsByID[modelID]; ok {
|
||||
models = append(models, model)
|
||||
continue
|
||||
}
|
||||
models = append(models, openai.Model{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: 1704067200,
|
||||
OwnedBy: "openai",
|
||||
Type: "model",
|
||||
DisplayName: modelID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func filterModelsByCustomList(availableModels, fallbackModels, selectedModels []string) []string {
|
||||
if len(selectedModels) == 0 {
|
||||
return availableModels
|
||||
}
|
||||
source := availableModels
|
||||
if len(source) == 0 {
|
||||
source = fallbackModels
|
||||
}
|
||||
if len(source) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
allowed := make([]string, 0, len(source))
|
||||
for _, model := range source {
|
||||
model = strings.TrimSpace(model)
|
||||
if model != "" {
|
||||
allowed = append(allowed, model)
|
||||
}
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(selectedModels))
|
||||
filtered := make([]string, 0, len(selectedModels))
|
||||
for _, model := range selectedModels {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
if !customModelsListAllowsModel(allowed, model) {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[model]; ok {
|
||||
continue
|
||||
}
|
||||
seen[model] = struct{}{}
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func customModelsListAllowsModel(availablePatterns []string, model string) bool {
|
||||
for _, pattern := range availablePatterns {
|
||||
if pattern == model {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(pattern, "*") && strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultModelIDsForPlatform(platform string) []string {
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
return openai.DefaultModelIDs()
|
||||
case service.PlatformGemini:
|
||||
ids := make([]string, 0, len(geminicli.DefaultModels))
|
||||
for _, model := range geminicli.DefaultModels {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
return ids
|
||||
case service.PlatformAntigravity:
|
||||
models := antigravity.DefaultModels()
|
||||
ids := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
return ids
|
||||
default:
|
||||
ids := make([]string, 0, len(claude.DefaultModels))
|
||||
for _, model := range claude.DefaultModels {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
}
|
||||
|
||||
// AntigravityModels 返回 Antigravity 支持的全部模型
|
||||
// GET /antigravity/models
|
||||
func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
||||
@ -1502,6 +1628,14 @@ func (h *GatewayHandler) sendFailoverKeepalivePing(c *gin.Context, streamStarted
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// /v1/responses 的严格 SDK(Codex CLI)要求终止事件必须属于
|
||||
// response.completed/failed/incomplete/cancelled 集合。
|
||||
// Anthropic-backed Responses 路径同样会因为通用 error 帧被拒。
|
||||
if inboundIsResponses(c) {
|
||||
if writeResponsesFailedSSE(c, errType, message) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
@ -1520,10 +1654,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
// Writer 已被写过时(ping 已 flush)走 streamStarted 分支,
|
||||
// 让 handleStreamingAwareError 通过 SSE 发协议合规的终止事件,
|
||||
// 否则下游收到的就是 silent EOF。
|
||||
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
if c == nil || c.Writer == nil {
|
||||
return false
|
||||
}
|
||||
if c.Writer.Written() {
|
||||
streamStarted = true
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
@ -1650,7 +1790,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
@ -1898,6 +2038,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算
|
||||
// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。
|
||||
func extractQuotaResetSeconds(err error) int {
|
||||
const fallback = 60
|
||||
appErr := pkgerrors.FromError(err)
|
||||
if appErr == nil {
|
||||
return fallback
|
||||
}
|
||||
raw, ok := appErr.Metadata["window_resets_at"]
|
||||
if !ok || raw == "" {
|
||||
return fallback
|
||||
}
|
||||
resetAt, parseErr := time.Parse(time.RFC3339, raw)
|
||||
if parseErr != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.billing"),
|
||||
zap.String("raw", raw),
|
||||
zap.Error(parseErr),
|
||||
).Warn("quota.invalid_window_resets_at_format")
|
||||
return fallback
|
||||
}
|
||||
secs := time.Until(resetAt).Seconds()
|
||||
if secs <= 0 {
|
||||
// reset 时间已过:cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避,
|
||||
// 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。
|
||||
return fallback
|
||||
}
|
||||
return int(math.Ceil(secs))
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
|
||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||
msg := pkgerrors.Message(err)
|
||||
@ -1925,6 +2095,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte
|
||||
retrySeconds := 60 - int(time.Now().Unix()%60)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
|
||||
}
|
||||
if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) ||
|
||||
errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) ||
|
||||
errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) {
|
||||
// 与 RPM 超限一致映射 429 + Retry-After,让 SDK 自动退避(而非 403 直接失败)。
|
||||
// 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err)
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
logger.L().With(
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
|
||||
require.Equal(t, "billing_error", code)
|
||||
require.NotEmpty(t, msg)
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) {
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339),
|
||||
})
|
||||
got := extractQuotaResetSeconds(err)
|
||||
if got < 10 || got > 11 {
|
||||
t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) {
|
||||
if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 {
|
||||
t.Errorf("T20: got %d, want 60 fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) {
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": "not-a-time",
|
||||
})
|
||||
if got := extractQuotaResetSeconds(err); got != 60 {
|
||||
t.Errorf("T21: got %d, want 60 fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) {
|
||||
// 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s:
|
||||
// 1 秒会导致客户端立即重试仍触发限额的退避循环;
|
||||
// 60s 让客户端按常规节奏退避,cache/DB 自愈期间不会反复打抖。
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339),
|
||||
})
|
||||
if got := extractQuotaResetSeconds(err); got != 60 {
|
||||
t.Errorf("T22: got %d, want 60 (fallback on past reset)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) {
|
||||
// quota 超限映射 429 + Retry-After(RFC 6585 / 与 RPM 一致),
|
||||
// 让 SDK(OpenAI 兼容客户端等)能按 Retry-After 自动退避。
|
||||
// 旧实现用 403 导致客户端不退避直接报错。
|
||||
// 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
{"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
{"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
status, code, _, retryAfter := billingErrorDetails(tc.err)
|
||||
if status != http.StatusTooManyRequests {
|
||||
t.Errorf("status = %d, want 429", status)
|
||||
}
|
||||
if code != "rate_limit_exceeded" {
|
||||
t.Errorf("code = %q, want rate_limit_exceeded", code)
|
||||
}
|
||||
if retryAfter < 3599 || retryAfter > 3601 {
|
||||
t.Errorf("retryAfter = %d, want ~3600", retryAfter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -33,7 +33,9 @@ func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testi
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
// Writer 已写后 ensureForwardErrorResponse 必须把错误以 SSE 形式追加,
|
||||
// 而不是 silent EOF。非 /responses 路径走 legacy data:{"type":"error"} 分支。
|
||||
func TestGatewayEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@ -43,7 +45,27 @@ func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *tes
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
assert.Contains(t, w.Body.String(), "already written")
|
||||
assert.Contains(t, w.Body.String(), `data: {"type":"error"`)
|
||||
}
|
||||
|
||||
// case B 回归:Anthropic-backed /responses,Writer 已被写过时
|
||||
// ensureForwardErrorResponse 仍要发 response.failed。
|
||||
func TestGatewayEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil)
|
||||
_, _ = c.Writer.WriteString(":\n\n")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, ":\n\n")
|
||||
assert.Contains(t, body, "event: response.failed\n")
|
||||
assert.Contains(t, body, `"type":"response.failed"`)
|
||||
}
|
||||
|
||||
@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
nil, // balanceNotifyService
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
@ -25,7 +25,11 @@ type gatewayModelsResponseForTest struct {
|
||||
}
|
||||
|
||||
type gatewayModelItemForTest struct {
|
||||
ID string `json:"id"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
@ -43,7 +47,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand
|
||||
gatewayService: service.NewGatewayService(
|
||||
repo,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
@ -127,6 +131,267 @@ func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
|
||||
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_CustomModelsListDisabledKeepsOriginalModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(22)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.5": "gpt-5.5",
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformOpenAI,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: false,
|
||||
Models: []string{"gpt-5.5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gpt-5.4", "gpt-5.5"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_CustomModelsListFiltersAndOrdersMappedModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(23)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
"gpt-5.5": "gpt-5.5",
|
||||
"legacy-gpt-2024": "legacy-gpt-2024",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformOpenAI,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: true,
|
||||
Models: []string{"gpt-5.5", "missing-model", "gpt-5.4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_CustomModelsListKeepsConcreteModelAllowedByWildcardMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(26)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformAnthropic,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: true,
|
||||
Models: []string{"claude-sonnet-4-6"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"claude-sonnet-4-6"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_CustomModelsListCanReturnEmptyWhenSelectionsUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(24)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformOpenAI,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: true,
|
||||
Models: []string{"gpt-5.5"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Empty(t, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_CustomModelsListFiltersDefaultFallbackModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(25)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{ID: 1, Platform: service.PlatformOpenAI},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformOpenAI,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: true,
|
||||
Models: []string{"gpt-5.5", "legacy-gpt-2024", "gpt-5.4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func TestGatewayModels_OpenAICustomModelsListKeepsOpenAIResponseShapeForDefaultFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(27)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{ID: 1, Platform: service.PlatformOpenAI},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{
|
||||
ID: groupID,
|
||||
Platform: service.PlatformOpenAI,
|
||||
ModelsListConfig: service.GroupModelsListConfig{
|
||||
Enabled: true,
|
||||
Models: []string{"gpt-5.5", "gpt-5.4"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gpt-5.5", "gpt-5.4"}, modelIDsForTest(got.Data))
|
||||
require.Equal(t, "model", got.Data[0].Object)
|
||||
require.NotZero(t, got.Data[0].Created)
|
||||
require.Equal(t, "openai", got.Data[0].OwnedBy)
|
||||
require.Empty(t, got.Data[0].CreatedAt)
|
||||
}
|
||||
|
||||
func modelIDsForTest(models []gatewayModelItemForTest) []string {
|
||||
ids := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
|
||||
@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil),
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
|
||||
@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
|
||||
253
backend/internal/handler/openai_embeddings.go
Normal file
253
backend/internal/handler/openai_embeddings.go
Normal file
@ -0,0 +1,253 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Embeddings handles the OpenAI-compatible Embeddings API.
|
||||
// POST /v1/embeddings
|
||||
func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) {
|
||||
streamStarted := false
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.embeddings",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || strings.TrimSpace(modelResult.String()) == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel))
|
||||
setOpsRequestContext(c, reqModel, false)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeSync))
|
||||
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, false, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai_embeddings.billing_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.errorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
switchCount := 0
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
if maxAccountSwitches <= 0 {
|
||||
maxAccountSwitches = 3
|
||||
}
|
||||
routingStart := time.Now()
|
||||
|
||||
for {
|
||||
selection, _, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
"",
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportHTTPSSE,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_embeddings.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, false)
|
||||
} else {
|
||||
h.errorResponse(c, http.StatusBadGateway, "api_error", "Upstream request failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
if account.Type != service.AccountTypeAPIKey {
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog)
|
||||
if !accountAcquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardEmbeddings(c.Request.Context(), c, account, forwardBody, "")
|
||||
}()
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, false)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_embeddings.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if c.Writer.Size() == writerSizeBeforeForward {
|
||||
h.errorResponse(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
|
||||
}
|
||||
reqLog.Warn("openai_embeddings.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.embeddings"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_embeddings.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_embeddings.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -1209,11 +1209,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
releaseAccountSlot := func() {
|
||||
if currentAccountRelease != nil {
|
||||
currentAccountRelease()
|
||||
currentAccountRelease = nil
|
||||
}
|
||||
}
|
||||
releaseTurnSlots := func() {
|
||||
releaseAccountSlot()
|
||||
if currentUserRelease != nil {
|
||||
currentUserRelease()
|
||||
currentUserRelease = nil
|
||||
@ -1233,9 +1236,26 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
ensureUserSlotHeld := func() bool {
|
||||
if currentUserRelease != nil {
|
||||
return true
|
||||
}
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_user_slot_reacquire_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||
return false
|
||||
}
|
||||
if !userAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||
return false
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
return true
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||
return
|
||||
@ -1246,195 +1266,244 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
firstMessage,
|
||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||
)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
for {
|
||||
reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
reqLog.Warn("openai.websocket_account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if lastFailoverErr != nil {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||
} else {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
}
|
||||
return
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
if selection == nil || selection.Account == nil {
|
||||
if lastFailoverErr != nil {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, lastFailoverErr)
|
||||
} else {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
}
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
InitialRequestModel: reqModel,
|
||||
BeforeRequest: func(turn int, payload []byte, originalModel string) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
|
||||
}
|
||||
model := strings.TrimSpace(originalModel)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
}
|
||||
if model == "" {
|
||||
model = reqModel
|
||||
}
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, model, payload); decision != nil && decision.Blocked {
|
||||
writeContentModerationWSError(ctx, wsConn, decision)
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, decision.Message, nil)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil {
|
||||
if result == nil || result.ImageCount <= 0 {
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.websocket_partial_error_with_image_result",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("image_count", result.ImageCount),
|
||||
zap.Error(turnErr),
|
||||
)
|
||||
}
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: inboundEndpoint,
|
||||
UpstreamEndpoint: upstreamEndpoint,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// 应用渠道模型映射到 WebSocket 首条消息
|
||||
wsFirstMessage := firstMessage
|
||||
if channelMappingWS.Mapped {
|
||||
wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
releaseAccountSlot()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
closeOpenAIWSFailoverExhausted(wsConn, failoverErr)
|
||||
return
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
reqLog.Warn("openai.websocket_upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if !ensureUserSlotHeld() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||
@ -1691,6 +1760,15 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
// /v1/responses 的严格 SDK(Codex CLI)要求终止事件必须属于
|
||||
// response.completed/failed/incomplete/cancelled 集合。
|
||||
// 通用 `event: error` 帧不被识别为终止事件,会导致
|
||||
// "stream closed before response.completed"。
|
||||
if inboundIsResponses(c) {
|
||||
if writeResponsesFailedSSE(c, errType, message) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
@ -1710,9 +1788,17 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
if c == nil || c.Writer == nil {
|
||||
return false
|
||||
}
|
||||
// 旧实现在 Writer.Written 时直接 return false,导致 ping 已 flush 之后的
|
||||
// 上游错误(http2 timeout、连接中断等)完全无法把错误传给客户端——
|
||||
// HTTP 200 已锁死,TCP 直接 EOF,Codex CLI 报 "stream closed before response.completed"。
|
||||
// 这里改成:Writer 已写过时强制走 streamStarted 分支,让
|
||||
// handleStreamingAwareError 通过 SSE 发协议合规的 response.failed。
|
||||
if c.Writer.Written() {
|
||||
streamStarted = true
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
@ -1783,6 +1869,23 @@ func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason s
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func closeOpenAIWSFailoverExhausted(conn *coderws.Conn, failoverErr *service.UpstreamFailoverError) {
|
||||
if failoverErr == nil {
|
||||
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
switch failoverErr.StatusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream rate limit exceeded, please retry later")
|
||||
case 529, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||
closeOpenAIClientWS(conn, coderws.StatusTryAgainLater, "upstream service temporarily unavailable")
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
closeOpenAIClientWS(conn, coderws.StatusPolicyViolation, "upstream websocket authentication failed")
|
||||
default:
|
||||
closeOpenAIClientWS(conn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
}
|
||||
}
|
||||
|
||||
func writeContentModerationWSError(ctx context.Context, conn *coderws.Conn, decision *service.ContentModerationDecision) {
|
||||
if conn == nil || decision == nil {
|
||||
return
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -174,7 +175,11 @@ func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testin
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
// Writer 已写后 ensureForwardErrorResponse 必须仍然把错误信息以 SSE
|
||||
// 形式追加给客户端(streamStarted 强制 true)。
|
||||
// 这是 case B 修复:旧实现遇到 Writer.Written 直接 return false,
|
||||
// 客户端只能拿到 silent EOF;Codex CLI 报 "stream closed before response.completed"。
|
||||
func TestOpenAIEnsureForwardErrorResponse_AppendsSSEAfterWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@ -184,9 +189,34 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.True(t, wrote, "must attempt to communicate the failure to the client via SSE")
|
||||
// 状态码改不了(headers 已 flush),但 body 应该追加 SSE 错误事件。
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
assert.Contains(t, w.Body.String(), "already written")
|
||||
// 非 /responses 路径走 legacy event: error 分支。
|
||||
assert.Contains(t, w.Body.String(), "event: error\n")
|
||||
}
|
||||
|
||||
// case B 回归测试:/responses 路径,Writer 已被写过(模拟 ping flushed),
|
||||
// ensureForwardErrorResponse 必须发 response.failed,让 Codex 收到合规终止事件。
|
||||
func TestOpenAIEnsureForwardErrorResponse_ResponsesRouteAfterWrittenEmitsResponseFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, EndpointResponses, nil)
|
||||
// 模拟 ping 已 flush 的状态:Writer 已写过 1 个字节
|
||||
_, _ = c.Writer.WriteString(":\n\n")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, ":\n\n", "earlier ping bytes preserved")
|
||||
assert.Contains(t, body, "event: response.failed\n", "appended a Responses terminal event")
|
||||
assert.Contains(t, body, `"type":"response.failed"`)
|
||||
assert.Contains(t, body, `"code":"upstream_error"`)
|
||||
assert.Contains(t, body, "Upstream request failed")
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
|
||||
@ -266,7 +296,9 @@ func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
// Panic 在已 flush 的 /v1/responses 流中:状态码无法改(已 written),
|
||||
// 但 body 应追加 response.failed 让客户端识别为合规截断而不是 silent EOF。
|
||||
func TestOpenAIRecoverResponsesPanic_AppendsResponseFailedAfterWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@ -284,7 +316,9 @@ func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, "already written")
|
||||
assert.Contains(t, body, "event: response.failed\n")
|
||||
}
|
||||
|
||||
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
|
||||
@ -707,16 +741,31 @@ func (r *contentModerationHandlerSettingRepo) Delete(ctx context.Context, key st
|
||||
}
|
||||
|
||||
type contentModerationHandlerTestRepo struct {
|
||||
mu sync.Mutex
|
||||
logs []service.ContentModerationLog
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) CreateLog(ctx context.Context, log *service.ContentModerationLog) error {
|
||||
if log != nil {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.logs = append(r.logs, *log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) resetLogs() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.logs = nil
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) logSnapshot() []service.ContentModerationLog {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return append([]service.ContentModerationLog(nil), r.logs...)
|
||||
}
|
||||
|
||||
func (r *contentModerationHandlerTestRepo) ListLogs(ctx context.Context, filter service.ContentModerationLogFilter) ([]service.ContentModerationLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@ -775,7 +824,10 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
repo.logs = nil
|
||||
require.Eventually(t, func() bool {
|
||||
return len(repo.logSnapshot()) == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
repo.resetLogs()
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
@ -815,10 +867,11 @@ func TestOpenAIResponsesWebSocket_ContentModerationBlocksFirstFrame(t *testing.T
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, closeErr.Reason, "内容审计测试阻断")
|
||||
}
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.True(t, repo.logs[0].Flagged)
|
||||
require.Equal(t, service.ContentModerationActionBlock, repo.logs[0].Action)
|
||||
require.Equal(t, "bad prompt", repo.logs[0].InputExcerpt)
|
||||
logs := repo.logSnapshot()
|
||||
require.Len(t, logs, 1)
|
||||
require.True(t, logs[0].Flagged)
|
||||
require.Equal(t, service.ContentModerationActionBlock, logs[0].Action)
|
||||
require.Equal(t, "bad prompt", logs[0].InputExcerpt)
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PassthroughUsageLogPersistsUserAgentAndReasoningEffort(t *testing.T) {
|
||||
@ -1042,6 +1095,52 @@ func (s *openAIWSUsageHandlerAccountRepoStub) GetByID(ctx context.Context, id in
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
type openAIWSFailoverHandlerAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
accounts []service.Account
|
||||
rateLimitedIDs []int64
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
out := make([]service.Account, 0, len(s.accounts))
|
||||
for _, account := range s.accounts {
|
||||
if account.Platform == platform && account.IsSchedulable() {
|
||||
out = append(out, account)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return s.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
for _, account := range s.accounts {
|
||||
if account.ID == id {
|
||||
acc := account
|
||||
return &acc, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *openAIWSFailoverHandlerAccountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateLimitedIDs = append(s.rateLimitedIDs, id)
|
||||
for i := range s.accounts {
|
||||
if s.accounts[i].ID == id {
|
||||
reset := resetAt
|
||||
s.accounts[i].RateLimitResetAt = &reset
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type openAIWSUsageHandlerUsageLogRepoStub struct {
|
||||
service.UsageLogRepository
|
||||
created chan *service.UsageLog
|
||||
@ -1074,6 +1173,201 @@ func (s *openAIWSUsageHandlerChannelRepoStub) GetGroupPlatforms(ctx context.Cont
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_FailoverOnUpstreamUsageLimitEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
firstHitCh := make(chan []byte, 1)
|
||||
secondHitCh := make(chan []byte, 1)
|
||||
|
||||
firstUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
firstHitCh <- payload
|
||||
}
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"error","error":{"code":"rate_limit_exceeded","type":"usage_limit_reached","message":"The usage limit has been reached"}}`))
|
||||
cancelWrite()
|
||||
}))
|
||||
defer firstUpstream.Close()
|
||||
|
||||
secondUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{CompressionMode: coderws.CompressionContextTakeover})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.CloseNow() }()
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_, payload, readErr := conn.Read(readCtx)
|
||||
cancelRead()
|
||||
if readErr == nil {
|
||||
secondHitCh <- payload
|
||||
}
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
_ = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.completed","response":{"id":"resp_ws_failover_ok","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`))
|
||||
cancelWrite()
|
||||
_ = conn.Close(coderws.StatusNormalClosure, "done")
|
||||
}))
|
||||
defer secondUpstream.Close()
|
||||
|
||||
groupID := int64(4202)
|
||||
accounts := []service.Account{
|
||||
{
|
||||
ID: 9902,
|
||||
Name: "openai-ws-rate-limited",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-first",
|
||||
"base_url": firstUpstream.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 9903,
|
||||
Name: "openai-ws-healthy",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 2,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-second",
|
||||
"base_url": secondUpstream.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
"openai_apikey_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.RunMode = config.RunModeSimple
|
||||
cfg.Default.RateMultiplier = 1
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
cfg.Gateway.MaxAccountSwitches = 3
|
||||
|
||||
accountRepo := &openAIWSFailoverHandlerAccountRepoStub{accounts: accounts}
|
||||
rateLimitSvc := service.NewRateLimitService(accountRepo, nil, cfg, nil, nil)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
service.NewBillingService(cfg, nil),
|
||||
rateLimitSvc,
|
||||
billingCacheSvc,
|
||||
nil,
|
||||
&service.DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: gatewaySvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
maxAccountSwitches: 3,
|
||||
}
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1802,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1702, Status: service.StatusActive},
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformOpenAI, Status: service.StatusActive},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.User.ID, Concurrency: 1})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
handlerServer := httptest.NewServer(router)
|
||||
defer handlerServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(
|
||||
dialCtx,
|
||||
"ws"+strings.TrimPrefix(handlerServer.URL, "http")+"/openai/v1/responses",
|
||||
&coderws.DialOptions{CompressionMode: coderws.CompressionContextTakeover},
|
||||
)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = clientConn.CloseNow() }()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_, event, err := clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
|
||||
require.Equal(t, "resp_ws_failover_ok", gjson.GetBytes(event, "response.id").String())
|
||||
|
||||
select {
|
||||
case <-firstHitCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待第一个上游收到首帧超时")
|
||||
}
|
||||
select {
|
||||
case <-secondHitCh:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("等待第二个上游收到重放首帧超时")
|
||||
}
|
||||
require.Equal(t, []int64{int64(9902)}, accountRepo.rateLimitedIDs)
|
||||
}
|
||||
|
||||
func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSUsageLogCase) openAIResponsesWSUsageLogResult {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
@ -1168,7 +1462,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
@ -1190,6 +1484,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
|
||||
@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
|
||||
@ -53,6 +53,8 @@ const (
|
||||
opsCodeUserNotFound = "USER_NOT_FOUND"
|
||||
opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED"
|
||||
opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated"
|
||||
opsCodeGroupDeleted = "GROUP_DELETED"
|
||||
opsCodeGroupDisabled = "GROUP_DISABLED"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -1012,6 +1014,8 @@ func parseOpsErrorResponse(body []byte) parsedOpsError {
|
||||
var code string
|
||||
if v, ok := errObj["code"]; ok {
|
||||
switch n := v.(type) {
|
||||
case string:
|
||||
code = strings.TrimSpace(n)
|
||||
case float64:
|
||||
code = strconvItoa(int(n))
|
||||
case int:
|
||||
@ -1190,14 +1194,19 @@ func isOpsClientAuthError(code string, msg string) bool {
|
||||
opsCodeAPIKeyExpired,
|
||||
opsCodeAPIKeyDisabled,
|
||||
opsCodeUserNotFound,
|
||||
opsCodeUserInactive:
|
||||
opsCodeUserInactive,
|
||||
opsCodeGroupDeleted,
|
||||
opsCodeGroupDisabled:
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "invalid api key") ||
|
||||
strings.Contains(msg, "api key is required") ||
|
||||
strings.Contains(msg, "api key is disabled") ||
|
||||
strings.Contains(msg, "user associated with api key not found") ||
|
||||
strings.Contains(msg, "user account is not active")
|
||||
strings.Contains(msg, "user account is not active") ||
|
||||
strings.Contains(msg, "api key 所属分组已删除") ||
|
||||
strings.Contains(msg, "api key 所属分组已停用") ||
|
||||
strings.Contains(msg, "api key is not assigned to any group")
|
||||
}
|
||||
|
||||
func isOpsLocalBusinessLimitError(code string, msg string) bool {
|
||||
@ -1213,6 +1222,7 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool {
|
||||
return strings.Contains(msg, "api key in query parameter is deprecated") ||
|
||||
strings.Contains(msg, "query parameter api_key is deprecated") ||
|
||||
strings.Contains(msg, "no active subscription found for this group") ||
|
||||
strings.Contains(msg, "subscription is invalid or expired") ||
|
||||
strings.Contains(msg, opsErrInsufficientBalance) ||
|
||||
strings.Contains(msg, "insufficient account balance") ||
|
||||
strings.Contains(msg, "api key group platform is not gemini") ||
|
||||
@ -1223,7 +1233,22 @@ func isOpsLocalBusinessLimitError(code string, msg string) bool {
|
||||
strings.Contains(msg, "daily usage limit exceeded") ||
|
||||
strings.Contains(msg, "weekly usage limit exceeded") ||
|
||||
strings.Contains(msg, "monthly usage limit exceeded") ||
|
||||
strings.Contains(msg, "requests-per-minute limit exceeded")
|
||||
strings.Contains(msg, "usage quota exhausted for this platform") ||
|
||||
strings.Contains(msg, "requests-per-minute limit exceeded") ||
|
||||
strings.Contains(msg, "too many pending requests") ||
|
||||
strings.Contains(msg, "concurrency limit exceeded") ||
|
||||
strings.Contains(msg, "image generation concurrency limit exceeded") ||
|
||||
strings.Contains(msg, "this group is restricted to claude code clients") ||
|
||||
strings.Contains(msg, "this group does not allow /v1/messages dispatch") ||
|
||||
strings.Contains(msg, "image generation is not enabled for this group") ||
|
||||
strings.Contains(msg, "token counting is not supported for this platform") ||
|
||||
strings.Contains(msg, "images api is not supported for this platform") ||
|
||||
(strings.Contains(msg, "model ") && strings.Contains(msg, " not in whitelist")) ||
|
||||
(strings.Contains(msg, "beta feature ") && strings.Contains(msg, " is not allowed")) ||
|
||||
(strings.Contains(msg, "openai service_tier=") && strings.Contains(msg, " is not allowed for model")) ||
|
||||
strings.Contains(msg, "this account only allows codex official clients") ||
|
||||
strings.Contains(msg, "openai wsv1 is temporarily unsupported") ||
|
||||
strings.Contains(msg, "openai codex passthrough requires a non-empty instructions field")
|
||||
}
|
||||
|
||||
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
|
||||
|
||||
@ -288,6 +288,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
code: "USER_INACTIVE",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "deleted local API key group",
|
||||
errType: "api_error",
|
||||
message: "API Key 所属分组已删除",
|
||||
code: "GROUP_DELETED",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "disabled local API key group",
|
||||
errType: "api_error",
|
||||
message: "API Key 所属分组已停用",
|
||||
code: "GROUP_DISABLED",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "google deleted API key group message without semantic code",
|
||||
errType: "api_error",
|
||||
message: "API Key 所属分组已删除",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "anthropic unassigned API key group",
|
||||
errType: "permission_error",
|
||||
message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "google invalid API key",
|
||||
errType: "api_error",
|
||||
@ -389,6 +417,15 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "gateway subscription invalid cache recheck",
|
||||
errType: "billing_error",
|
||||
message: "subscription is invalid or expired",
|
||||
code: "billing_error",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "billing_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google insufficient account balance",
|
||||
errType: "api_error",
|
||||
@ -443,6 +480,132 @@ func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "user platform daily quota exhausted",
|
||||
errType: "api_error",
|
||||
message: "Daily usage quota exhausted for this platform.",
|
||||
code: "rate_limit_exceeded",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "local pending queue limit",
|
||||
errType: "rate_limit_error",
|
||||
message: "Too many pending requests, please retry later",
|
||||
code: "",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "rate_limit_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "local concurrency limit",
|
||||
errType: "rate_limit_error",
|
||||
message: "Concurrency limit exceeded for user, please retry later",
|
||||
code: "",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "rate_limit_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "group claude code only feature gate",
|
||||
errType: "permission_error",
|
||||
message: "This group is restricted to Claude Code clients (/v1/messages only)",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "group image generation feature gate",
|
||||
errType: "permission_error",
|
||||
message: "Image generation is not enabled for this group",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "route token counting platform unsupported",
|
||||
errType: "not_found_error",
|
||||
message: "Token counting is not supported for this platform",
|
||||
code: "",
|
||||
status: http.StatusNotFound,
|
||||
wantErrType: "not_found_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "route images API platform unsupported",
|
||||
errType: "not_found_error",
|
||||
message: "Images API is not supported for this platform",
|
||||
code: "",
|
||||
status: http.StatusNotFound,
|
||||
wantErrType: "not_found_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "antigravity model whitelist feature gate",
|
||||
errType: "permission_error",
|
||||
message: "model claude-3-5-sonnet not in whitelist",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google antigravity model whitelist feature gate",
|
||||
errType: "api_error",
|
||||
message: "model gemini-2.5-pro not in whitelist",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "claude beta policy block",
|
||||
errType: "invalid_request_error",
|
||||
message: "beta feature interleaved-thinking-2025-05-14 is not allowed",
|
||||
code: "",
|
||||
status: http.StatusBadRequest,
|
||||
wantErrType: "invalid_request_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "openai fast policy block",
|
||||
errType: "permission_error",
|
||||
message: "openai service_tier=priority is not allowed for model gpt-5.5",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "codex official client policy block",
|
||||
errType: "forbidden_error",
|
||||
message: "This account only allows Codex official clients",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "forbidden_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "openai wsv1 unsupported feature gate",
|
||||
errType: "invalid_request_error",
|
||||
message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
|
||||
code: "",
|
||||
status: http.StatusBadRequest,
|
||||
wantErrType: "invalid_request_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "openai passthrough instructions policy block",
|
||||
errType: "forbidden_error",
|
||||
message: "OpenAI codex passthrough requires a non-empty instructions field",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "forbidden_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -479,6 +642,22 @@ func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsClientBusinessLimitedMarkerExcludesCustomPolicyDenialFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied)
|
||||
|
||||
errType := normalizeOpsErrorType("invalid_request_error", "")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "custom admin policy message", "", http.StatusBadRequest)
|
||||
|
||||
require.Equal(t, "invalid_request_error", errType)
|
||||
require.Equal(t, "auth", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@ -583,6 +762,78 @@ func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
|
||||
code: "API_KEY_QUOTA_EXHAUSTED",
|
||||
status: http.StatusTooManyRequests,
|
||||
},
|
||||
{
|
||||
name: "provider deleted group shaped error",
|
||||
message: "API Key 所属分组已删除",
|
||||
code: "GROUP_DELETED",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider unassigned group shaped error",
|
||||
message: "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider local quota shaped error",
|
||||
message: "Daily usage quota exhausted for this platform.",
|
||||
code: "rate_limit_exceeded",
|
||||
status: http.StatusTooManyRequests,
|
||||
},
|
||||
{
|
||||
name: "provider feature gate shaped error",
|
||||
message: "Image generation is not enabled for this group",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider token counting unsupported shaped error",
|
||||
message: "Token counting is not supported for this platform",
|
||||
code: "404",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "provider image API unsupported shaped error",
|
||||
message: "Images API is not supported for this platform",
|
||||
code: "404",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "provider antigravity whitelist shaped error",
|
||||
message: "model claude-3-5-sonnet not in whitelist",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider beta policy shaped error",
|
||||
message: "beta feature interleaved-thinking-2025-05-14 is not allowed",
|
||||
code: "400",
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "provider openai fast policy shaped error",
|
||||
message: "openai service_tier=priority is not allowed for model gpt-5.5",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider codex client policy shaped error",
|
||||
message: "This account only allows Codex official clients",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider wsv1 unsupported shaped error",
|
||||
message: "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
|
||||
code: "400",
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "provider passthrough instructions shaped error",
|
||||
message: "OpenAI codex passthrough requires a non-empty instructions field",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -628,6 +879,14 @@ func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
}
|
||||
|
||||
func TestParseOpsErrorResponsePreservesNestedStringCode(t *testing.T) {
|
||||
parsed := parseOpsErrorResponse([]byte(`{"error":{"type":"permission_error","code":"GROUP_DELETED","message":"API Key 所属分组已删除"}}`))
|
||||
|
||||
require.Equal(t, "permission_error", parsed.ErrorType)
|
||||
require.Equal(t, "GROUP_DELETED", parsed.Code)
|
||||
require.Equal(t, "API Key 所属分组已删除", parsed.Message)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
104
backend/internal/handler/quotaview/helpers.go
Normal file
104
backend/internal/handler/quotaview/helpers.go
Normal file
@ -0,0 +1,104 @@
|
||||
// Package quotaview provides shared quota response helpers for user and admin handlers.
|
||||
// Extracted to avoid import cycles between handler and handler/admin packages.
|
||||
package quotaview
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB)。
|
||||
// includeWindowStart=true 时输出 *_window_start 字段(admin 视角调试用)
|
||||
func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any {
|
||||
daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart)
|
||||
weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart)
|
||||
monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart)
|
||||
out := map[string]any{
|
||||
"platform": r.Platform,
|
||||
"daily_usage_usd": daily.usage,
|
||||
"daily_limit_usd": daily.limit,
|
||||
"daily_window_resets_at": daily.resetsAt,
|
||||
"weekly_usage_usd": weekly.usage,
|
||||
"weekly_limit_usd": weekly.limit,
|
||||
"weekly_window_resets_at": weekly.resetsAt,
|
||||
"monthly_usage_usd": monthly.usage,
|
||||
"monthly_limit_usd": monthly.limit,
|
||||
"monthly_window_resets_at": monthly.resetsAt,
|
||||
}
|
||||
if includeWindowStart {
|
||||
out["daily_window_start"] = daily.windowStart
|
||||
out["weekly_window_start"] = weekly.windowStart
|
||||
out["monthly_window_start"] = monthly.windowStart
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type windowSlice struct {
|
||||
usage float64
|
||||
limit *float64
|
||||
resetsAt *string
|
||||
windowStart *string
|
||||
}
|
||||
|
||||
func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice {
|
||||
out := windowSlice{usage: usage, limit: limit}
|
||||
if expired {
|
||||
out.usage = 0
|
||||
out.resetsAt = nil
|
||||
} else if start != nil {
|
||||
s := nextReset.Format(time.RFC3339)
|
||||
out.resetsAt = &s
|
||||
}
|
||||
if includeStart && start != nil {
|
||||
s := start.Format(time.RFC3339)
|
||||
out.windowStart = &s
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NeedsDailyReset 判断日窗口是否已过期:start 早于「全局时区当天 0 点」即过期。
|
||||
// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。
|
||||
func NeedsDailyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return start.Before(timezone.StartOfDay(now))
|
||||
}
|
||||
|
||||
func NeedsWeeklyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return start.Before(timezone.StartOfWeek(now))
|
||||
}
|
||||
|
||||
// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。
|
||||
func NeedsMonthlyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return now.Sub(*start) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
func nextDailyResetTime(now time.Time) time.Time {
|
||||
return timezone.StartOfDay(now).AddDate(0, 0, 1)
|
||||
}
|
||||
|
||||
func nextWeeklyResetTime(now time.Time) time.Time {
|
||||
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
|
||||
}
|
||||
|
||||
// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。
|
||||
// 语义:
|
||||
// - start != nil → 返回 start + 30d(与 billing_cache_service.nextMonthlyResetFrom 一致)
|
||||
// - start == nil → 退化为 now + 30d(保留旧行为,避免 nil 崩溃)
|
||||
//
|
||||
// 导出(首字母大写)以允许测试直接调用。
|
||||
func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time {
|
||||
if start == nil {
|
||||
return now.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
return start.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
133
backend/internal/handler/quotaview/helpers_test.go
Normal file
133
backend/internal/handler/quotaview/helpers_test.go
Normal file
@ -0,0 +1,133 @@
|
||||
package quotaview
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TestNextMonthlyResetTimeFrom_FromStart 验证:start 已知时返回 start+30d,不随 now 漂移。
|
||||
func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) {
|
||||
t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
now := t0.Add(15 * 24 * time.Hour) // t0 + 15d
|
||||
want := t0.Add(30 * 24 * time.Hour) // t0 + 30d
|
||||
|
||||
got := NextMonthlyResetTimeFrom(&t0, now)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextMonthlyResetTimeFrom_NilStart 验证:start=nil 时退化为 now+30d(不 panic)。
|
||||
func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) {
|
||||
now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC)
|
||||
want := now.Add(30 * 24 * time.Hour)
|
||||
|
||||
got := NextMonthlyResetTimeFrom(nil, now)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证:
|
||||
// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record,
|
||||
// monthly_window_resets_at 始终等于 windowStart+30d,不随 now 漂移。
|
||||
func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) {
|
||||
windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339)
|
||||
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "openai",
|
||||
MonthlyUsageUSD: 5.0,
|
||||
MonthlyWindowStart: &windowStart,
|
||||
}
|
||||
|
||||
// 第一次调用:now = windowStart + 5d
|
||||
now1 := windowStart.Add(5 * 24 * time.Hour)
|
||||
out1 := LazyZeroQuotaForResponse(r, now1, false)
|
||||
resetsAt1, ok1 := out1["monthly_window_resets_at"]
|
||||
if !ok1 || resetsAt1 == nil {
|
||||
t.Fatal("first call: monthly_window_resets_at should be set for active window")
|
||||
}
|
||||
s1, ok := resetsAt1.(*string)
|
||||
if !ok || s1 == nil {
|
||||
t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1)
|
||||
}
|
||||
if *s1 != wantResetsAt {
|
||||
t.Errorf("first call: want %s, got %s", wantResetsAt, *s1)
|
||||
}
|
||||
|
||||
// 第二次调用:now = windowStart + 10d(不同 now,但 resetsAt 应不变)
|
||||
now2 := windowStart.Add(10 * 24 * time.Hour)
|
||||
out2 := LazyZeroQuotaForResponse(r, now2, false)
|
||||
resetsAt2, ok2 := out2["monthly_window_resets_at"]
|
||||
if !ok2 || resetsAt2 == nil {
|
||||
t.Fatal("second call: monthly_window_resets_at should be set for active window")
|
||||
}
|
||||
s2, ok := resetsAt2.(*string)
|
||||
if !ok || s2 == nil {
|
||||
t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2)
|
||||
}
|
||||
if *s2 != wantResetsAt {
|
||||
t.Errorf("second call: want %s, got %s", wantResetsAt, *s2)
|
||||
}
|
||||
|
||||
// 两次结果必须相等
|
||||
if *s1 != *s2 {
|
||||
t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。
|
||||
func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
// now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08(北京 5/26)
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC)
|
||||
|
||||
// start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08(北京 5/25)→ 应判定为过期
|
||||
startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC)
|
||||
if !NeedsDailyReset(&startPrevBeijingDay, now) {
|
||||
t.Error("上一个北京日的窗口应判定为过期")
|
||||
}
|
||||
|
||||
// start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08(北京 5/26 同日)→ 不应过期
|
||||
startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC)
|
||||
if NeedsDailyReset(&startSameBeijingDay, now) {
|
||||
t.Error("同一北京日的窗口不应判定为过期")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。
|
||||
func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00
|
||||
want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00
|
||||
if got := nextDailyResetTime(now); !got.Equal(want) {
|
||||
t.Errorf("nextDailyResetTime = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。
|
||||
func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
// 北京 2026-05-26(周二)→ 下周一是 2026-06-01
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二
|
||||
want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location())
|
||||
if got := nextWeeklyResetTime(now); !got.Equal(want) {
|
||||
t.Errorf("nextWeeklyResetTime = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
165
backend/internal/handler/stream_error_event.go
Normal file
165
backend/internal/handler/stream_error_event.go
Normal file
@ -0,0 +1,165 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// responsesFailedError 对齐 OpenAI Responses 协议 error 子对象。
|
||||
type responsesFailedError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// responsesFailedBody 对齐 apicompat.makeResponsesCompletedEvent 输出的 response 子对象字段集。
|
||||
// Output 用空 slice(不是 nil)确保 marshal 为 `[]` 而非 `null`。
|
||||
type responsesFailedBody struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Output []any `json:"output"`
|
||||
Error responsesFailedError `json:"error"`
|
||||
}
|
||||
|
||||
// responsesFailedEvent 是写入 SSE data 行的顶层结构。
|
||||
// 故意不带 sequence_number:spec 标记可选,且本函数被调用时无法可靠拿到 last seq。
|
||||
type responsesFailedEvent struct {
|
||||
Type string `json:"type"`
|
||||
Response responsesFailedBody `json:"response"`
|
||||
}
|
||||
|
||||
// writeResponsesFailedSSE emits a `response.failed` SSE event in the OpenAI
|
||||
// Responses API protocol after the stream has already started.
|
||||
//
|
||||
// 必要性:一旦 SSE 头和任意数据(例如等待槽位时的 ping comment)已经 flush,
|
||||
// HTTP 200 状态码就被固化。此后若网关需要回报错误,只能继续通过 SSE 事件传达。
|
||||
// 通用的 `event: error` 帧不是 Responses 协议规定的终止事件,
|
||||
// Codex CLI 等严格 SDK 会因为没收到 `response.completed/failed/incomplete/cancelled`
|
||||
// 而抛出 "stream closed before response.completed"。
|
||||
//
|
||||
// 字段集对齐 apicompat.makeResponsesCompletedEvent:id/object/model/status/output/error。
|
||||
// 故意不写 sequence_number:本函数被调用时无法可靠拿到当前流的 last sequence,
|
||||
// 而 OpenAI spec 将 sequence_number 设为可选;省略避免破坏单调性约束。
|
||||
//
|
||||
// 返回 true 表示已尝试 SSE 写出(不论 Write 是否成功,caller 都应直接 return)。
|
||||
// 返回 false 表示 writer 不支持 Flusher,无法以 SSE 形式回报错误;
|
||||
// 此时 caller 也无法回退到 JSON(HTTP 200 已固化),通常意味着连接已经损坏,
|
||||
// 应当让请求处理函数 return,由上层关闭连接。
|
||||
func writeResponsesFailedSSE(c *gin.Context, errType, message string) bool {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(responsesFailedEvent{
|
||||
Type: "response.failed",
|
||||
Response: responsesFailedBody{
|
||||
ID: synthesizeResponseID(c),
|
||||
Object: "response",
|
||||
Model: requestModel(c),
|
||||
Status: "failed",
|
||||
Output: []any{},
|
||||
Error: responsesFailedError{
|
||||
Code: mapResponsesErrorCode(errType),
|
||||
Message: message,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return true
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(c.Writer, "event: response.failed\ndata: %s\n\n", payload); err != nil {
|
||||
_ = c.Error(err)
|
||||
return true
|
||||
}
|
||||
flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
// inboundIsResponses 判断当前请求是否落在任何 /responses 路由上。
|
||||
//
|
||||
// 不能直接用 GetInboundEndpoint(c) == EndpointResponses 比较,因为
|
||||
// NormalizeInboundEndpoint 只识别包含 "/v1/responses" 子串的路径;
|
||||
// 项目里实际注册了多组路由(gateway_v1、top-level bare、codex direct),
|
||||
// 其中 r.POST("/responses", ...) 和 codexDirect.POST("/responses", ...)
|
||||
// 的 c.FullPath() 不含 "/v1/" 前缀,会被归一化为原始路径,
|
||||
// 导致协议合规终止事件没法发出去。
|
||||
//
|
||||
// 这里用 FullPath 的后缀判断,覆盖所有变体:
|
||||
// - /v1/responses
|
||||
// - /v1/responses/compact
|
||||
// - /responses
|
||||
// - /responses/compact
|
||||
// - /backend-api/codex/responses
|
||||
// - /backend-api/codex/responses/compact
|
||||
func inboundIsResponses(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
p := strings.TrimRight(c.FullPath(), "/")
|
||||
if p == "" && c.Request != nil && c.Request.URL != nil {
|
||||
p = strings.TrimRight(c.Request.URL.Path, "/")
|
||||
}
|
||||
if p == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(p, "/responses") || strings.Contains(p, "/responses/")
|
||||
}
|
||||
|
||||
// synthesizeResponseID 为合成的 response.failed 事件生成一个稳定的 id。
|
||||
// 优先复用 server 端生成的 request_id(存在 request.Context 里,由 request_logger 写入),
|
||||
// 以便客户端报错能与 server 日志关联;缺失时回退 uuid。
|
||||
func synthesizeResponseID(c *gin.Context) string {
|
||||
if c != nil && c.Request != nil {
|
||||
if rid, ok := c.Request.Context().Value(ctxkey.RequestID).(string); ok {
|
||||
if rid = strings.TrimSpace(rid); rid != "" {
|
||||
return "resp_" + strings.ReplaceAll(rid, "-", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
return "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
|
||||
// requestModel 取当前请求的 inbound model(由 setOpsRequestContext 写入)。
|
||||
// 缺失时返回 "";caller 据此决定是否忽略该字段。
|
||||
func requestModel(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// mapResponsesErrorCode 把内部 errType 映射为 Responses 协议常见的 error.code。
|
||||
// 无明确映射时原样返回,保证至少可读。
|
||||
func mapResponsesErrorCode(errType string) string {
|
||||
switch errType {
|
||||
case "rate_limit_error":
|
||||
return "rate_limit_exceeded"
|
||||
case "invalid_request_error":
|
||||
return "invalid_request"
|
||||
case "permission_error":
|
||||
return "permission_denied"
|
||||
case "authentication_error":
|
||||
return "authentication_failed"
|
||||
case "upstream_error":
|
||||
return "upstream_error"
|
||||
case "server_error", "api_error", "":
|
||||
return "server_error"
|
||||
default:
|
||||
return errType
|
||||
}
|
||||
}
|
||||
253
backend/internal/handler/stream_error_event_test.go
Normal file
253
backend/internal/handler/stream_error_event_test.go
Normal file
@ -0,0 +1,253 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Regression for the production incident on 2026-05-24 around 9:13 CST:
|
||||
// user 16 sent /v1/responses with stream:true via Codex CLI; the user-concurrency
|
||||
// slot wait sent SSE ping comments (flushing HTTP 200 + headers), then the 30s
|
||||
// timeout fired and the handler emitted `event: error\ndata: {...}`. Codex CLI
|
||||
// does not recognize that as a Responses terminal event and reports
|
||||
// "stream closed before response.completed". The fix is to emit a synthetic
|
||||
// response.failed event when the inbound endpoint is /v1/responses.
|
||||
|
||||
func newGinContextForEndpoint(t *testing.T, endpoint string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, endpoint, nil)
|
||||
return c, w
|
||||
}
|
||||
|
||||
// parseResponsesFailedSSE 抽出 SSE 中 data 行的 JSON,返回 (response 对象, error 对象)。
|
||||
func parseResponsesFailedSSE(t *testing.T, body string) (map[string]any, map[string]any) {
|
||||
t.Helper()
|
||||
require.True(t, strings.HasPrefix(body, "event: response.failed\n"),
|
||||
"expect event: response.failed prefix, got: %q", body)
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.SplitN(strings.TrimSuffix(body, "\n\n"), "\n", 2)
|
||||
require.Len(t, lines, 2)
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "))
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data must be valid JSON: %s", jsonStr)
|
||||
|
||||
assert.Equal(t, "response.failed", parsed["type"])
|
||||
// 故意不发 sequence_number,避免与后续真实事件的序号冲突。
|
||||
_, hasSeq := parsed["sequence_number"]
|
||||
assert.False(t, hasSeq, "synthetic event must not emit sequence_number")
|
||||
|
||||
resp, ok := parsed["response"].(map[string]any)
|
||||
require.True(t, ok, "response object missing")
|
||||
assert.Equal(t, "response", resp["object"])
|
||||
assert.Equal(t, "failed", resp["status"])
|
||||
|
||||
errObj, ok := resp["error"].(map[string]any)
|
||||
require.True(t, ok, "error object missing")
|
||||
|
||||
return resp, errObj
|
||||
}
|
||||
|
||||
// OpenAI handler: /v1/responses streaming, after stream started, must emit response.failed.
|
||||
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
"Concurrency limit exceeded for user, please retry later", true)
|
||||
|
||||
resp, errObj := parseResponsesFailedSSE(t, w.Body.String())
|
||||
|
||||
id, _ := resp["id"].(string)
|
||||
assert.True(t, strings.HasPrefix(id, "resp_"), "id should start with resp_, got %q", id)
|
||||
assert.Equal(t, "rate_limit_exceeded", errObj["code"])
|
||||
assert.Equal(t, "Concurrency limit exceeded for user, please retry later", errObj["message"])
|
||||
}
|
||||
|
||||
// 当 setOpsRequestContext 写过 model,合成事件应回填该字段(与 codebase 已有 makeResponsesCompletedEvent 对齐)。
|
||||
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingIncludesModel(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
setOpsRequestContext(c, "gpt-5.5", true)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
|
||||
|
||||
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
|
||||
assert.Equal(t, "gpt-5.5", resp["model"])
|
||||
}
|
||||
|
||||
// 没有 model 时 model 字段不应出现(避免发空字符串污染下游解析)。
|
||||
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingOmitsEmptyModel(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
|
||||
|
||||
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
|
||||
_, hasModel := resp["model"]
|
||||
assert.False(t, hasModel, "model field must be omitted when unknown")
|
||||
}
|
||||
|
||||
// 当 request.Context 携带 ctxkey.RequestID 时,合成 id 应与之关联,便于和 server log 串起来。
|
||||
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingReusesRequestID(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
c.Request = c.Request.WithContext(
|
||||
context.WithValue(c.Request.Context(), ctxkey.RequestID, "fd277bc5-ff7e-45d1-8aa9-f54e1df318f1"),
|
||||
)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "x", true)
|
||||
|
||||
resp, _ := parseResponsesFailedSSE(t, w.Body.String())
|
||||
assert.Equal(t, "resp_fd277bc5ff7e45d18aa9f54e1df318f1", resp["id"])
|
||||
}
|
||||
|
||||
// 与旧分支的 TestOpenAIHandleStreamingAwareError_JSONEscaping 对齐:
|
||||
// 新的 response.failed payload 也必须正确转义 message 里的特殊字符,
|
||||
// 否则下游 SDK 解析 JSON 时会失败。
|
||||
func TestOpenAIHandleStreamingAwareError_ResponsesStreamingJSONEscaping(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{"双引号", "server_error", `upstream returned "invalid" response`},
|
||||
{"反斜杠", "server_error", `path C:\Users\test\file.txt not found`},
|
||||
{"双引号+反斜杠", "upstream_error", `error parsing "key\value": unexpected token`},
|
||||
{"换行与制表", "server_error", "line1\nline2\ttab"},
|
||||
{"普通", "upstream_error", "Upstream service temporarily unavailable"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tc.errType, tc.message, true)
|
||||
|
||||
_, errObj := parseResponsesFailedSSE(t, w.Body.String())
|
||||
assert.Equal(t, tc.message, errObj["message"], "message 必须被原样还原")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI handler: /v1/chat/completions streaming keeps the legacy event: error format
|
||||
// (out of scope for this fix; covered to prevent regression of unrelated paths).
|
||||
func TestOpenAIHandleStreamingAwareError_ChatCompletionsStreamingKeepsLegacy(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointChatCompletions)
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
|
||||
|
||||
body := w.Body.String()
|
||||
assert.True(t, strings.HasPrefix(body, "event: error\n"), "got: %q", body)
|
||||
}
|
||||
|
||||
// Gateway (Anthropic-backed) handler: /v1/responses path also must emit response.failed.
|
||||
func TestGatewayHandleStreamingAwareError_ResponsesStreamingEmitsResponseFailed(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointResponses)
|
||||
h := &GatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "upstream gone", true)
|
||||
|
||||
_, errObj := parseResponsesFailedSSE(t, w.Body.String())
|
||||
assert.Equal(t, "upstream_error", errObj["code"])
|
||||
assert.Equal(t, "upstream gone", errObj["message"])
|
||||
}
|
||||
|
||||
// Gateway handler: /v1/messages preserves the legacy data:{type:error,...} format
|
||||
// (Anthropic spec accepts a type:"error" stream event).
|
||||
func TestGatewayHandleStreamingAwareError_MessagesStreamingKeepsLegacy(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, EndpointMessages)
|
||||
h := &GatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "boom", true)
|
||||
|
||||
body := w.Body.String()
|
||||
assert.True(t, strings.HasPrefix(body, `data: {"type":"error"`), "got: %q", body)
|
||||
}
|
||||
|
||||
// 项目里 /responses 注册在多组路由:/v1/responses(gateway)、裸 /responses(top-level)、
|
||||
// /backend-api/codex/responses(codex direct)。我们 fix 必须覆盖全部,
|
||||
// 否则一些客户端走的路径就不会发 response.failed,照样报 stream closed。
|
||||
// 这是生产 2026-05-24 ~11:05 UTC user 16 实际命中的 bug。
|
||||
func TestInboundIsResponses_CoversAllRoutes(t *testing.T) {
|
||||
cases := []struct {
|
||||
route string
|
||||
want bool
|
||||
}{
|
||||
{"/v1/responses", true},
|
||||
{"/v1/responses/compact", true},
|
||||
{"/responses", true}, // <-- 用户 16 实际走这条
|
||||
{"/responses/compact", true},
|
||||
{"/backend-api/codex/responses", true},
|
||||
{"/backend-api/codex/responses/compact", true},
|
||||
{"/v1/chat/completions", false},
|
||||
{"/v1/messages", false},
|
||||
{"/", false},
|
||||
{"/responses-fake", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.route, func(t *testing.T) {
|
||||
c, _ := newGinContextForEndpoint(t, tc.route)
|
||||
assert.Equal(t, tc.want, inboundIsResponses(c), "route=%q", tc.route)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 用 c.Request.URL.Path 作为 fallback(当 c.FullPath() 为空时,例如某些测试 fixture)。
|
||||
func TestInboundIsResponses_FallsBackToURLPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses", nil)
|
||||
// 这种情况下 c.FullPath() 是 "",必须 fallback 到 URL.Path
|
||||
assert.True(t, inboundIsResponses(c), "URL.Path fallback must work when FullPath is empty")
|
||||
}
|
||||
|
||||
// 回归生产事故:用户 16 走 /responses 路径,必须发 response.failed。
|
||||
func TestOpenAIHandleStreamingAwareError_BareResponsesRouteEmitsResponseFailed(t *testing.T) {
|
||||
c, w := newGinContextForEndpoint(t, "/responses")
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
"Concurrency limit exceeded for user, please retry later", true)
|
||||
|
||||
resp, errObj := parseResponsesFailedSSE(t, w.Body.String())
|
||||
id, _ := resp["id"].(string)
|
||||
assert.True(t, strings.HasPrefix(id, "resp_"))
|
||||
assert.Equal(t, "rate_limit_exceeded", errObj["code"])
|
||||
}
|
||||
|
||||
// Synthesized response.failed id falls back to uuid when no request_id is present.
|
||||
func TestSynthesizeResponseID_FallbackUUID(t *testing.T) {
|
||||
c, _ := newGinContextForEndpoint(t, EndpointResponses)
|
||||
id := synthesizeResponseID(c)
|
||||
assert.True(t, strings.HasPrefix(id, "resp_"))
|
||||
// uuid 去掉短横线后 32 hex 字符;前缀 "resp_" 共 37。
|
||||
assert.Len(t, id, 37)
|
||||
}
|
||||
|
||||
func TestMapResponsesErrorCode(t *testing.T) {
|
||||
cases := []struct{ in, out string }{
|
||||
{"rate_limit_error", "rate_limit_exceeded"},
|
||||
{"invalid_request_error", "invalid_request"},
|
||||
{"permission_error", "permission_denied"},
|
||||
{"authentication_error", "authentication_failed"},
|
||||
{"upstream_error", "upstream_error"},
|
||||
{"server_error", "server_error"},
|
||||
{"api_error", "server_error"},
|
||||
{"", "server_error"},
|
||||
{"custom_thing", "custom_thing"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
assert.Equal(t, tc.out, mapResponsesErrorCode(tc.in), "in=%q", tc.in)
|
||||
}
|
||||
}
|
||||
@ -3,8 +3,10 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -14,11 +16,12 @@ import (
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
affiliateService *service.AffiliateService
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
affiliateService *service.AffiliateService
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
@ -28,16 +31,44 @@ func NewUserHandler(
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
affiliateService *service.AffiliateService,
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMyPlatformQuotas GET /user/platform-quotas
|
||||
// 返回当前 JWT 用户的 platform quota 状态。
|
||||
// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null(不写 DB)
|
||||
func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Success(c, map[string]any{"platform_quotas": []any{}})
|
||||
return
|
||||
}
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
|
||||
@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
|
||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -285,7 +289,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
212
backend/internal/handler/user_platform_quotas_handler_test.go
Normal file
212
backend/internal/handler/user_platform_quotas_handler_test.go
Normal file
@ -0,0 +1,212 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集
|
||||
type fakeQuotaRepoForUserHandler struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
records []service.UserPlatformQuotaRecord
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return f.records, nil
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) {
|
||||
repo := &fakeQuotaRepoForUserHandler{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
PlatformQuotas []any `json:"platform_quotas"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String())
|
||||
}
|
||||
if body.Code != 0 {
|
||||
t.Errorf("expected code=0, got %d", body.Code)
|
||||
}
|
||||
if body.Data.PlatformQuotas == nil {
|
||||
// nil 和 empty slice 均视为可接受(JSON 可能序列化为 null 或 [])
|
||||
// 此断言只验证 HTTP 200 + code=0 即可
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) {
|
||||
pastStart := time.Now().UTC().AddDate(0, 0, -2)
|
||||
daily := 5.0
|
||||
repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{
|
||||
UserID: 42,
|
||||
Platform: "anthropic",
|
||||
DailyLimitUSD: &daily,
|
||||
DailyUsageUSD: 3.0,
|
||||
DailyWindowStart: &pastStart,
|
||||
}}}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 解析 response,验证过期 daily 的 usage_usd=0 且 window_resets_at=null
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"daily_usage_usd":0`) {
|
||||
t.Errorf("expected daily_usage_usd:0 in body, got: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, `"daily_window_resets_at":null`) {
|
||||
t.Errorf("expected daily_window_resets_at:null in body, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
// 不设置 auth subject
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 401 {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-1 * time.Hour)
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "anthropic",
|
||||
DailyUsageUSD: 1.0,
|
||||
DailyWindowStart: &start,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false)
|
||||
if _, ok := out["daily_window_start"]; ok {
|
||||
t.Error("user view should not include daily_window_start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-1 * time.Hour)
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "anthropic",
|
||||
DailyWindowStart: &start,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true)
|
||||
if _, ok := out["daily_window_start"]; !ok {
|
||||
t.Error("admin view should include daily_window_start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) {
|
||||
// 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径
|
||||
now := time.Now()
|
||||
today := timezone.StartOfDay(now)
|
||||
usage := 2.5
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "openai",
|
||||
DailyUsageUSD: usage,
|
||||
DailyWindowStart: &today,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, now, false)
|
||||
if out["daily_usage_usd"] != usage {
|
||||
t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"])
|
||||
}
|
||||
// 活跃窗口应有 resets_at(非 nil)
|
||||
if out["daily_window_resets_at"] == nil {
|
||||
t.Error("active window should have daily_window_resets_at set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsDailyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) {
|
||||
old := time.Now().UTC().AddDate(0, 0, -1)
|
||||
if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) {
|
||||
t.Error("yesterday start should need daily reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need weekly reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need monthly reset")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义(C-NEW-1)。
|
||||
func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期
|
||||
if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
|
||||
t.Error("31 days ago should need monthly reset (30-day rolling)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效
|
||||
if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
|
||||
t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。
|
||||
func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) {
|
||||
// 窗口起始 4 月 20 日;5 月 1 日仅过了 11 天,不足 30 天,不应重置
|
||||
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
|
||||
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
if quotaview.NeedsMonthlyReset(&windowStart, now) {
|
||||
t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)")
|
||||
}
|
||||
}
|
||||
@ -517,6 +517,33 @@ func TestResponsesEventToAnthropicEvents_ResponseDone(t *testing.T) {
|
||||
assert.Nil(t, FinalizeResponsesAnthropicStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_TopLevelTerminalUsage(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 6,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 5,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, events, 2)
|
||||
assert.Equal(t, "message_delta", events[0].Type)
|
||||
require.NotNil(t, events[0].Usage)
|
||||
assert.Equal(t, 15, events[0].Usage.InputTokens)
|
||||
assert.Equal(t, 5, events[0].Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 6, events[0].Usage.OutputTokens)
|
||||
assert.Equal(t, "message_stop", events[1].Type)
|
||||
}
|
||||
|
||||
func TestResponsesEventToAnthropicEvents_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToAnthropicState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
@ -846,6 +846,33 @@ func TestResponsesEventToChatChunks_ResponseDone(t *testing.T) {
|
||||
assert.Nil(t, FinalizeResponsesChatStream(state))
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_TopLevelTerminalUsage(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.IncludeUsage = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.completed",
|
||||
Response: &ResponsesResponse{
|
||||
Status: "completed",
|
||||
},
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: 21,
|
||||
OutputTokens: 9,
|
||||
InputTokensDetails: &ResponsesInputTokensDetails{
|
||||
CachedTokens: 4,
|
||||
},
|
||||
},
|
||||
}, state)
|
||||
|
||||
require.Len(t, chunks, 2)
|
||||
require.NotNil(t, chunks[1].Usage)
|
||||
assert.Equal(t, 21, chunks[1].Usage.PromptTokens)
|
||||
assert.Equal(t, 9, chunks[1].Usage.CompletionTokens)
|
||||
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
|
||||
assert.Equal(t, 4, chunks[1].Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ResponseDoneIncomplete(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
|
||||
@ -567,6 +567,12 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
events = append(events, closeCurrentBlock(state)...)
|
||||
|
||||
stopReason := "end_turn"
|
||||
if evt.Usage != nil {
|
||||
usage := anthropicUsageFromResponsesUsage(evt.Usage)
|
||||
state.InputTokens = usage.InputTokens
|
||||
state.OutputTokens = usage.OutputTokens
|
||||
state.CacheReadInputTokens = usage.CacheReadInputTokens
|
||||
}
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
|
||||
|
||||
@ -293,20 +293,12 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
state.Finalized = true
|
||||
finishReason := "stop"
|
||||
|
||||
if evt.Usage != nil {
|
||||
state.Usage = chatUsageFromResponsesUsage(evt.Usage)
|
||||
}
|
||||
if evt.Response != nil {
|
||||
if evt.Response.Usage != nil {
|
||||
u := evt.Response.Usage
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
state.Usage = usage
|
||||
state.Usage = chatUsageFromResponsesUsage(evt.Response.Usage)
|
||||
}
|
||||
|
||||
switch evt.Response.Status {
|
||||
@ -340,6 +332,23 @@ func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
|
||||
return chunks
|
||||
}
|
||||
|
||||
func chatUsageFromResponsesUsage(u *ResponsesUsage) *ChatUsage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
usage := &ChatUsage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.InputTokens + u.OutputTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails = &ChatTokenDetails{
|
||||
CachedTokens: u.InputTokensDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
|
||||
return ChatCompletionsChunk{
|
||||
ID: state.ID,
|
||||
|
||||
@ -380,6 +380,8 @@ type ResponsesStreamEvent struct {
|
||||
|
||||
// response.created / response.completed / response.done / response.failed / response.incomplete
|
||||
Response *ResponsesResponse `json:"response,omitempty"`
|
||||
// 部分 OpenAI 兼容上游会把 usage 放在终止事件顶层,而不是 response.usage。
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
|
||||
// response.output_item.added / response.output_item.done
|
||||
Item *ResponsesOutput `json:"item,omitempty"`
|
||||
|
||||
@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) {
|
||||
_ = Now()
|
||||
_ = StartOfDay(Now())
|
||||
}
|
||||
|
||||
func TestStartOfWeek_Boundaries(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = Init("UTC") })
|
||||
|
||||
loc := Location()
|
||||
wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
in time.Time
|
||||
}{
|
||||
{"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)},
|
||||
{"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)},
|
||||
{"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
if got := StartOfWeek(c.in); !got.Equal(wantMon) {
|
||||
t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1085,7 +1085,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time, reason ...string) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
}
|
||||
@ -1094,6 +1094,11 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
if len(reason) > 0 {
|
||||
if value := strings.TrimSpace(reason[0]); value != "" {
|
||||
payload["reason"] = value
|
||||
}
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1129,6 +1134,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -183,6 +183,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
group.FieldMessagesDispatchModelConfig,
|
||||
group.FieldModelsListConfig,
|
||||
group.FieldRpmLimit,
|
||||
)
|
||||
}).
|
||||
@ -723,6 +724,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
|
||||
ModelsListConfig: g.ModelsListConfig,
|
||||
RPMLimit: g.RpmLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
|
||||
@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6
|
||||
key := billingRateLimitKey(keyID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// user × platform quota 缓存
|
||||
// ============================================
|
||||
|
||||
// userPlatformQuotaCacheKey 构造 Redis key
|
||||
func userPlatformQuotaCacheKey(userID int64, platform string) string {
|
||||
return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform)
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) {
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
fields := []string{
|
||||
"daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version",
|
||||
"daily_limit", "weekly_limit", "monthly_limit",
|
||||
"daily_window_start", "weekly_window_start", "monthly_window_start",
|
||||
}
|
||||
vals, err := c.rdb.HMGet(ctx, key, fields...).Result()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
// 前4个全为nil → key 不存在
|
||||
if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
parseFloat := func(v any) float64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
f, _ := strconv.ParseFloat(s, 64)
|
||||
return f
|
||||
}
|
||||
parseFloatPtr := func(v any) *float64 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return nil
|
||||
}
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &f
|
||||
}
|
||||
parseTimePtr := func(v any) *time.Time {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return nil
|
||||
}
|
||||
n, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
t := time.Unix(n, 0).UTC()
|
||||
return &t
|
||||
}
|
||||
parseInt64 := func(v any) int64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
n, _ := strconv.ParseInt(s, 10, 64)
|
||||
return n
|
||||
}
|
||||
return &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: parseFloat(vals[0]),
|
||||
WeeklyUsageUSD: parseFloat(vals[1]),
|
||||
MonthlyUsageUSD: parseFloat(vals[2]),
|
||||
Version: parseInt64(vals[3]),
|
||||
SchemaVersion: parseInt64(vals[4]),
|
||||
DailyLimitUSD: parseFloatPtr(vals[5]),
|
||||
WeeklyLimitUSD: parseFloatPtr(vals[6]),
|
||||
MonthlyLimitUSD: parseFloatPtr(vals[7]),
|
||||
DailyWindowStart: parseTimePtr(vals[8]),
|
||||
WeeklyWindowStart: parseTimePtr(vals[9]),
|
||||
MonthlyWindowStart: parseTimePtr(vals[10]),
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
pipe := c.rdb.TxPipeline()
|
||||
|
||||
// 浮点可空字段:nil → 空字符串(读取时 parseFloatPtr 返回 nil,表示无限额)
|
||||
fmtFloatPtr := func(p *float64) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatFloat(*p, 'f', -1, 64)
|
||||
}
|
||||
// time.Time 可空字段:nil → 空字符串;有值 → unix 秒
|
||||
fmtTimePtr := func(p *time.Time) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatInt(p.Unix(), 10)
|
||||
}
|
||||
|
||||
pipe.HSet(ctx, key,
|
||||
"daily_usage", entry.DailyUsageUSD,
|
||||
"weekly_usage", entry.WeeklyUsageUSD,
|
||||
"monthly_usage", entry.MonthlyUsageUSD,
|
||||
"version", entry.Version,
|
||||
"schema_version", entry.SchemaVersion,
|
||||
"daily_limit", fmtFloatPtr(entry.DailyLimitUSD),
|
||||
"weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD),
|
||||
"monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD),
|
||||
"daily_window_start", fmtTimePtr(entry.DailyWindowStart),
|
||||
"weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart),
|
||||
"monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart),
|
||||
)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
|
||||
return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err()
|
||||
}
|
||||
|
||||
// updateUserPlatformQuotaUsageScript 缓存累加:EXISTS + schema_version 双重守卫。
|
||||
// 旧版 entry(schema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后
|
||||
// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。
|
||||
// key 不存在同样跳过(由下次 SetCache 重建)。
|
||||
// KEYS[1] = hash key
|
||||
// ARGV[1] = cost (string float)
|
||||
// ARGV[2] = ttl seconds
|
||||
// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1)
|
||||
const updateUserPlatformQuotaUsageScript = `
|
||||
if redis.call("EXISTS", KEYS[1]) == 0 then
|
||||
return 0
|
||||
end
|
||||
local ver = redis.call("HGET", KEYS[1], "schema_version")
|
||||
if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then
|
||||
return 0
|
||||
end
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1])
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1])
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1])
|
||||
redis.call("HINCRBY", KEYS[1], "version", 1)
|
||||
redis.call("EXPIRE", KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`
|
||||
|
||||
func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
_, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key},
|
||||
strconv.FormatFloat(cost, 'f', -1, 64),
|
||||
int(ttl.Seconds()),
|
||||
service.UserPlatformQuotaCacheSchemaV1,
|
||||
).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) {
|
||||
t.Helper()
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
return &billingCache{rdb: rdb}, mr
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok || entry != nil {
|
||||
t.Errorf("expected miss, got ok=%v entry=%v", ok, entry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
dailyLimit := 20.0
|
||||
ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
in := &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 1.5,
|
||||
WeeklyUsageUSD: 3.0,
|
||||
MonthlyUsageUSD: 10.0,
|
||||
Version: 7,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
DailyLimitUSD: &dailyLimit,
|
||||
DailyWindowStart: &ts,
|
||||
}
|
||||
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("get: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 {
|
||||
t.Errorf("got = %+v, want %+v", got, in)
|
||||
}
|
||||
if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1)
|
||||
}
|
||||
if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit {
|
||||
t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit)
|
||||
}
|
||||
if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) {
|
||||
t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
in := &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 1.0,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
// DailyLimitUSD nil → 无限额
|
||||
}
|
||||
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("get: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if got.DailyLimitUSD != nil {
|
||||
t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai")
|
||||
if ok {
|
||||
t.Error("expected key absent after no-op incr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
// SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。
|
||||
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{
|
||||
Version: 1,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
}, time.Minute)
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 {
|
||||
t.Errorf("got %+v, want daily/weekly/monthly=0.75", got)
|
||||
}
|
||||
if got.Version != 3 {
|
||||
t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_Delete(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute)
|
||||
if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if ok {
|
||||
t.Error("expected miss after delete")
|
||||
}
|
||||
}
|
||||
@ -192,6 +192,7 @@ SELECT COUNT(*)
|
||||
FROM content_moderation_logs
|
||||
WHERE user_id = $1
|
||||
AND flagged = TRUE
|
||||
AND action <> 'hash_block'
|
||||
AND created_at >= $2
|
||||
AND created_at > COALESCE((SELECT at FROM last_auto_ban), '-infinity'::timestamptz)
|
||||
`, userID, since).Scan(&count)
|
||||
@ -246,7 +247,7 @@ func buildContentModerationLogWhere(filter service.ContentModerationLogFilter) (
|
||||
case "hit", "flagged":
|
||||
where = append(where, "l.flagged = TRUE")
|
||||
case "blocked", "block":
|
||||
where = append(where, "l.action = 'block'")
|
||||
where = append(where, "l.action IN ('block', 'keyword_block', 'hash_block')")
|
||||
case "pass", "allow":
|
||||
where = append(where, "l.flagged = FALSE AND l.error = ''")
|
||||
case "error":
|
||||
|
||||
40
backend/internal/repository/content_moderation_repo_test.go
Normal file
40
backend/internal/repository/content_moderation_repo_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildContentModerationLogWhere_BlockedIncludesAllBlockActions(t *testing.T) {
|
||||
where, args := buildContentModerationLogWhere(service.ContentModerationLogFilter{Result: "blocked"})
|
||||
|
||||
require.Empty(t, args)
|
||||
sql := strings.Join(where, " AND ")
|
||||
require.Contains(t, sql, "l.action IN ('block', 'keyword_block', 'hash_block')")
|
||||
require.NotContains(t, sql, "l.action = 'block'")
|
||||
}
|
||||
|
||||
func TestContentModerationRepositoryCountFlaggedByUserSince_ExcludesHashBlock(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
repo := NewContentModerationRepository(db)
|
||||
since := time.Now().Add(-time.Hour)
|
||||
mock.ExpectQuery(regexp.QuoteMeta("AND action <> 'hash_block'")).
|
||||
WithArgs(int64(1001), since).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(2))
|
||||
|
||||
count, err := repo.CountFlaggedByUserSince(context.Background(), 1001, since)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
@ -66,6 +66,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetModelsListConfig(groupIn.ModelsListConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 设置模型路由配置
|
||||
@ -141,6 +142,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
||||
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
|
||||
SetModelsListConfig(groupIn.ModelsListConfig).
|
||||
SetRpmLimit(groupIn.RPMLimit)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
|
||||
@ -3,6 +3,8 @@ package repository
|
||||
import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -50,6 +52,17 @@ const (
|
||||
defaultMaxUpstreamClients = 5000
|
||||
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
|
||||
defaultClientIdleTTLSeconds = 900
|
||||
// OpenAI HTTP/2 代理回退策略默认值
|
||||
defaultOpenAIHTTP2FallbackErrorThreshold = 2
|
||||
defaultOpenAIHTTP2FallbackWindow = 60 * time.Second
|
||||
defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
upstreamProtocolModeDefault = "default"
|
||||
upstreamProtocolModeOpenAIH1 = "openai_h1"
|
||||
upstreamProtocolModeOpenAIH2 = "openai_h2"
|
||||
upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback"
|
||||
)
|
||||
|
||||
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
|
||||
@ -64,14 +77,30 @@ type poolSettings struct {
|
||||
responseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
}
|
||||
|
||||
type openAIHTTP2Settings struct {
|
||||
enabled bool
|
||||
allowProxyFallbackToHTTP1 bool
|
||||
fallbackErrorThreshold int
|
||||
fallbackWindow time.Duration
|
||||
fallbackTTL time.Duration
|
||||
}
|
||||
|
||||
// upstreamClientEntry 上游客户端缓存条目
|
||||
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
|
||||
type upstreamClientEntry struct {
|
||||
client *http.Client // HTTP 客户端实例
|
||||
proxyKey string // 代理标识(用于检测代理变更)
|
||||
poolKey string // 连接池配置标识(用于检测配置变更)
|
||||
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
|
||||
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
|
||||
client *http.Client // HTTP 客户端实例
|
||||
proxyKey string // 代理标识(用于检测代理变更)
|
||||
poolKey string // 连接池配置标识(用于检测配置变更)
|
||||
protocolMode string // 协议模式(default/openai_h1/openai_h2/openai_h1_fallback)
|
||||
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
|
||||
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
|
||||
}
|
||||
|
||||
type openAIHTTP2FallbackState struct {
|
||||
mu sync.Mutex
|
||||
windowStart time.Time
|
||||
errorCount int
|
||||
fallbackUntil time.Time
|
||||
}
|
||||
|
||||
// httpUpstreamService 通用 HTTP 上游服务
|
||||
@ -95,6 +124,8 @@ type httpUpstreamService struct {
|
||||
cfg *config.Config // 全局配置
|
||||
mu sync.RWMutex // 保护 clients map 的读写锁
|
||||
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
|
||||
// OpenAI 走 HTTP/HTTPS 代理时的 H2->H1 回退状态(key=标准化 proxyKey)
|
||||
openAIHTTP2Fallbacks sync.Map
|
||||
}
|
||||
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
@ -142,9 +173,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile := service.HTTPUpstreamProfileDefault
|
||||
if req != nil {
|
||||
profile = service.HTTPUpstreamProfileFromContext(req.Context())
|
||||
}
|
||||
|
||||
// 获取或创建对应的客户端,并标记请求占用
|
||||
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
|
||||
entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -152,11 +187,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err)
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
return nil, err
|
||||
}
|
||||
s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey)
|
||||
|
||||
// 如果上游返回了压缩内容,解压后再交给业务层
|
||||
decompressResponseBody(resp)
|
||||
@ -179,6 +216,10 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
|
||||
if profile == nil {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
upstreamProfile := service.HTTPUpstreamProfileDefault
|
||||
if req != nil {
|
||||
upstreamProfile = service.HTTPUpstreamProfileFromContext(req.Context())
|
||||
}
|
||||
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
@ -194,7 +235,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
@ -219,21 +260,23 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, upstreamProfile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, upstreamProfile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
settings = s.applyProfilePoolSettings(settings, upstreamProfile)
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault)
|
||||
poolKey := buildPoolKey(settings, upstreamProtocolModeDefault) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
@ -284,7 +327,6 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", logredact.RedactProxyURL(proxyKey))
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
@ -350,7 +392,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req
|
||||
// acquireClient 获取或创建客户端,并标记为进行中请求
|
||||
// 用于请求路径,避免在获取后被淘汰
|
||||
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
|
||||
return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault)
|
||||
}
|
||||
|
||||
// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。
|
||||
func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getOrCreateClient 获取或创建客户端
|
||||
@ -369,13 +416,13 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
|
||||
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
|
||||
// - account_proxy: 按账户+代理组合隔离,最细粒度
|
||||
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false)
|
||||
}
|
||||
|
||||
// getClientEntry 获取或创建客户端条目
|
||||
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
|
||||
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
|
||||
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
// 获取隔离模式
|
||||
isolation := s.getIsolationMode()
|
||||
// 标准化代理 URL 并解析
|
||||
@ -383,10 +430,14 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 根据请求 profile(例如 OpenAI)选择协议模式
|
||||
protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
settings = s.applyProfilePoolSettings(settings, profile)
|
||||
// 构建缓存键(根据隔离策略不同)
|
||||
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
|
||||
cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode)
|
||||
// 构建连接池配置键(用于检测配置变更)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency)
|
||||
poolKey := buildPoolKey(settings, protocolMode)
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
@ -429,8 +480,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
||||
}
|
||||
|
||||
// 缓存未命中或需要重建,创建新客户端
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build transport: %w", err)
|
||||
@ -440,9 +490,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
protocolMode: protocolMode,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
@ -626,22 +677,31 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu
|
||||
return settings
|
||||
}
|
||||
|
||||
// buildPoolKey 构建连接池配置键
|
||||
// 用于检测配置变更,配置变更时需要重建客户端
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - string: 配置键
|
||||
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
|
||||
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
|
||||
if accountConcurrency > 0 {
|
||||
return fmt.Sprintf("account:%d", accountConcurrency)
|
||||
}
|
||||
func (s *httpUpstreamService) applyProfilePoolSettings(settings poolSettings, profile service.HTTPUpstreamProfile) poolSettings {
|
||||
if profile != service.HTTPUpstreamProfileOpenAI {
|
||||
return settings
|
||||
}
|
||||
return "default"
|
||||
settings.responseHeaderTimeout = 0
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIResponseHeaderTimeout > 0 {
|
||||
settings.responseHeaderTimeout = time.Duration(s.cfg.Gateway.OpenAIResponseHeaderTimeout) * time.Second
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
// buildPoolKey 构建连接池配置键,用于检测连接池配置变更。
|
||||
func buildPoolKey(settings poolSettings, protocolMode string) string {
|
||||
base := fmt.Sprintf(
|
||||
"idle:%d|idle_host:%d|max:%d|idle_timeout:%s|header_timeout:%s",
|
||||
settings.maxIdleConns,
|
||||
settings.maxIdleConnsPerHost,
|
||||
settings.maxConnsPerHost,
|
||||
settings.idleConnTimeout,
|
||||
settings.responseHeaderTimeout,
|
||||
)
|
||||
if protocolMode == "" || protocolMode == upstreamProtocolModeDefault {
|
||||
return base
|
||||
}
|
||||
return base + "|proto:" + protocolMode
|
||||
}
|
||||
|
||||
// buildCacheKey 构建客户端缓存键
|
||||
@ -659,15 +719,245 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency
|
||||
// - proxy 模式: "proxy:{proxyKey}"
|
||||
// - account 模式: "account:{accountID}"
|
||||
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
|
||||
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
|
||||
func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string {
|
||||
var base string
|
||||
switch isolation {
|
||||
case config.ConnectionPoolIsolationAccount:
|
||||
return fmt.Sprintf("account:%d", accountID)
|
||||
base = fmt.Sprintf("account:%d", accountID)
|
||||
case config.ConnectionPoolIsolationAccountProxy:
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
|
||||
base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
|
||||
default:
|
||||
return fmt.Sprintf("proxy:%s", proxyKey)
|
||||
base = fmt.Sprintf("proxy:%s", proxyKey)
|
||||
}
|
||||
if protocolMode != "" && protocolMode != upstreamProtocolModeDefault {
|
||||
base += "|proto:" + protocolMode
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings {
|
||||
settings := openAIHTTP2Settings{
|
||||
enabled: false,
|
||||
allowProxyFallbackToHTTP1: true,
|
||||
fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold,
|
||||
fallbackWindow: defaultOpenAIHTTP2FallbackWindow,
|
||||
fallbackTTL: defaultOpenAIHTTP2FallbackTTL,
|
||||
}
|
||||
if s == nil || s.cfg == nil {
|
||||
return settings
|
||||
}
|
||||
cfg := s.cfg.Gateway.OpenAIHTTP2
|
||||
settings.enabled = cfg.Enabled
|
||||
settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1
|
||||
if cfg.FallbackErrorThreshold > 0 {
|
||||
settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold
|
||||
}
|
||||
if cfg.FallbackWindowSeconds > 0 {
|
||||
settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second
|
||||
}
|
||||
if cfg.FallbackTTLSeconds > 0 {
|
||||
settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string {
|
||||
if profile != service.HTTPUpstreamProfileOpenAI {
|
||||
return upstreamProtocolModeDefault
|
||||
}
|
||||
settings := s.resolveOpenAIHTTP2Settings()
|
||||
if !settings.enabled {
|
||||
return upstreamProtocolModeOpenAIH1
|
||||
}
|
||||
if parsedProxy == nil {
|
||||
return upstreamProtocolModeOpenAIH2
|
||||
}
|
||||
scheme := strings.ToLower(parsedProxy.Scheme)
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return upstreamProtocolModeOpenAIH2
|
||||
}
|
||||
if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) {
|
||||
return upstreamProtocolModeOpenAIH1Fallback
|
||||
}
|
||||
return upstreamProtocolModeOpenAIH2
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool {
|
||||
raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
state, ok := raw.(*openAIHTTP2FallbackState)
|
||||
if !ok || state == nil {
|
||||
return false
|
||||
}
|
||||
return state.isFallbackActive(time.Now())
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState {
|
||||
state := &openAIHTTP2FallbackState{}
|
||||
actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state)
|
||||
cached, ok := actual.(*openAIHTTP2FallbackState)
|
||||
if !ok || cached == nil {
|
||||
return state
|
||||
}
|
||||
return cached
|
||||
}
|
||||
|
||||
func isHTTPProxyKey(proxyKey string) bool {
|
||||
return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://")
|
||||
}
|
||||
|
||||
func isOpenAIHTTP2CompatibilityError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if isUpstreamTimeoutError(err) {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
markers := []string{
|
||||
"alpn",
|
||||
"no application protocol",
|
||||
"protocol error",
|
||||
"stream error",
|
||||
"goaway",
|
||||
"refused_stream",
|
||||
"frame too large",
|
||||
}
|
||||
for _, marker := range markers {
|
||||
if strings.Contains(msg, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isUpstreamTimeoutError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
timeoutMarkers := []string{
|
||||
"timeout awaiting response headers",
|
||||
"i/o timeout",
|
||||
"context deadline exceeded",
|
||||
"client.timeout exceeded while awaiting headers",
|
||||
"tls handshake timeout",
|
||||
}
|
||||
for _, marker := range timeoutMarkers {
|
||||
if strings.Contains(msg, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) {
|
||||
if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
|
||||
return
|
||||
}
|
||||
settings := s.resolveOpenAIHTTP2Settings()
|
||||
if !settings.enabled || !settings.allowProxyFallbackToHTTP1 {
|
||||
return
|
||||
}
|
||||
if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) {
|
||||
return
|
||||
}
|
||||
state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
|
||||
activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL)
|
||||
if activated {
|
||||
slog.Warn("openai_http2_proxy_fallback_activated",
|
||||
"proxy", proxyKey,
|
||||
"fallback_until", until.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) {
|
||||
if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
|
||||
return
|
||||
}
|
||||
if !isHTTPProxyKey(proxyKey) {
|
||||
return
|
||||
}
|
||||
raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
state, ok := raw.(*openAIHTTP2FallbackState)
|
||||
if !ok || state == nil {
|
||||
return
|
||||
}
|
||||
state.resetErrorWindow()
|
||||
}
|
||||
|
||||
func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.fallbackUntil.IsZero() {
|
||||
return false
|
||||
}
|
||||
if now.Before(s.fallbackUntil) {
|
||||
return true
|
||||
}
|
||||
s.fallbackUntil = time.Time{}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *openAIHTTP2FallbackState) resetErrorWindow() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.windowStart = time.Time{}
|
||||
s.errorCount = 0
|
||||
}
|
||||
|
||||
func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) {
|
||||
if threshold <= 0 {
|
||||
threshold = defaultOpenAIHTTP2FallbackErrorThreshold
|
||||
}
|
||||
if window <= 0 {
|
||||
window = defaultOpenAIHTTP2FallbackWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = defaultOpenAIHTTP2FallbackTTL
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) {
|
||||
return false, s.fallbackUntil
|
||||
}
|
||||
if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) {
|
||||
s.fallbackUntil = time.Time{}
|
||||
}
|
||||
|
||||
if s.windowStart.IsZero() || now.Sub(s.windowStart) > window {
|
||||
s.windowStart = now
|
||||
s.errorCount = 0
|
||||
}
|
||||
s.errorCount++
|
||||
if s.errorCount < threshold {
|
||||
return false, time.Time{}
|
||||
}
|
||||
|
||||
s.fallbackUntil = now.Add(ttl)
|
||||
s.windowStart = time.Time{}
|
||||
s.errorCount = 0
|
||||
return true, s.fallbackUntil
|
||||
}
|
||||
|
||||
// normalizeProxyURL 标准化代理 URL
|
||||
@ -739,7 +1029,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
|
||||
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
|
||||
}
|
||||
if cfg.Gateway.ResponseHeaderTimeout > 0 {
|
||||
if cfg.Gateway.ResponseHeaderTimeout >= 0 {
|
||||
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
}
|
||||
}
|
||||
@ -770,7 +1060,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
||||
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
||||
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
@ -778,6 +1068,17 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
}
|
||||
switch protocolMode {
|
||||
case upstreamProtocolModeOpenAIH2:
|
||||
transport.ForceAttemptHTTP2 = true
|
||||
case upstreamProtocolModeOpenAIH1:
|
||||
transport.ForceAttemptHTTP2 = false
|
||||
transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
|
||||
case upstreamProtocolModeOpenAIH1Fallback:
|
||||
// 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。
|
||||
transport.ForceAttemptHTTP2 = false
|
||||
transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
|
||||
}
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user